Remove Tensorflow modules

This commit removes the following Tensorflow 1.0 related modules:

- spring-tensorflow-common
- spring-image-recognition-function
- spring-object-detection-function
- spring-semantic-segmentation-function

These may be resurrected at a later date but re-implemented with
either DeepJavaLibrary or TensorFlow 2.0.

* Remove Tensorflow modules from README.adoc

* Remove protobuf - was only used by Tensorflow
This commit is contained in:
Chris Bono
2024-01-16 16:02:49 -06:00
committed by GitHub
parent 5a3af2fb97
commit 0f65110e89
75 changed files with 22 additions and 8669 deletions


@@ -26,78 +26,70 @@ This functions catalog is also a foundation for https://spring.io/projects/sprin
|link:supplier/spring-ftp-supplier/README.adoc[FTP]
|link:function/spring-header-enricher-function/README.adoc[Header-Enricher]
|link:consumer/spring-elasticsearch-consumer/README.adoc[Elasticsearch]
|
|link:function/spring-header-filter-function/README.adoc[Header-Filter]
|link:consumer/spring-file-consumer/README.adoc[File]
|link:supplier/spring-http-supplier/README.adoc[HTTP]
|link:function/spring-http-request-function/README.adoc[HTTP Request]
|link:function/spring-header-filter-function/README.adoc[Header-Filter]
|link:consumer/spring-ftp-consumer/README.adoc[FTP]
|link:supplier/spring-jdbc-supplier/README.adoc[JDBC]
|link:function/spring-image-recognition-function/README.adoc[Image Recognition(Tensorflow)]
|
|link:supplier/spring-jms-supplier/README.adoc[JMS]
|link:function/spring-object-detection-function/README.adoc[Object Detection(Tensorflow)]
|link:function/spring-http-request-function/README.adoc[HTTP Request]
|link:consumer/spring-jdbc-consumer/README.adoc[JDBC]
|link:supplier/spring-mail-supplier/README.adoc[Mail]
|link:function/spring-semantic-segmentation-function/README.adoc[Semantic Segmentation(Tensorflow)]
|link:supplier/spring-jms-supplier/README.adoc[JMS]
|link:function/spring-spel-function/README.adoc[SpEL]
|link:consumer/spring-log-consumer/README.adoc[Log]
|link:supplier/spring-mongodb-supplier/README.adoc[MongoDB]
|link:function/spring-spel-function/README.adoc[SpEL]
|link:supplier/spring-mail-supplier/README.adoc[Mail]
|link:function/spring-splitter-function/README.adoc[Splitter]
|link:consumer/spring-mongodb-consumer/README.adoc[MongoDB]
|link:supplier/spring-mqtt-supplier/README.adoc[MQTT]
|link:function/spring-splitter-function/README.adoc[Splitter]
|link:supplier/spring-mongodb-supplier/README.adoc[MongoDB]
|link:function/spring-task-launch-request-function/README.adoc[Task Launch Request]
|link:consumer/spring-mqtt-consumer/README.adoc[MQTT]
|link:supplier/spring-rabbit-supplier/README.adoc[RabbitMQ]
|link:function/spring-task-launch-request-function/README.adoc[Task Launch Request]
|link:supplier/spring-mqtt-supplier/README.adoc[MQTT]
|link:function/spring-twitter-function/README.adoc[Twitter]
|link:consumer/spring-rabbit-consumer/README.adoc[RabbitMQ]
|link:supplier/spring-s3-supplier/README.adoc[AWS S3]
|
|link:consumer/spring-redis-consumer/README.adoc[Redis]
|link:supplier/spring-sftp-supplier/README.adoc[SFTP]
|link:supplier/spring-rabbit-supplier/README.adoc[RabbitMQ]
|
|link:consumer/spring-rsocket-consumer/README.adoc[RSocket]
|link:supplier/spring-syslog-supplier/README.adoc[Syslog]
|link:supplier/spring-sftp-supplier/README.adoc[SFTP]
|
|link:consumer/spring-s3-consumer/README.adoc[AWS S3]
|link:supplier/spring-tcp-supplier/README.adoc[TCP]
|link:supplier/spring-syslog-supplier/README.adoc[Syslog]
|
|link:consumer/spring-sftp-consumer/README.adoc[SFTP]
|link:supplier/spring-time-supplier/README.adoc[Time]
|link:supplier/spring-tcp-supplier/README.adoc[TCP]
|
|link:consumer/spring-tcp-consumer/README.adoc[TCP]
|link:supplier/spring-twitter-supplier/README.adoc[Twitter]
|link:function/spring-twitter-function/README.adoc[Twitter]
|link:supplier/spring-time-supplier/README.adoc[Time]
|
|link:consumer/spring-twitter-consumer/README.adoc[Twitter]
|link:supplier/spring-websocket-supplier/README.adoc[Websocket]
|link:supplier/spring-twitter-supplier/README.adoc[Twitter]
|
|link:consumer/spring-websocket-consumer/README.adoc[Websocket]
|
|link:supplier/spring-websocket-supplier/README.adoc[Websocket]
|
|link:consumer/spring-wavefront-consumer/README.adoc[Wavefront]
|link:supplier/spring-xmpp-supplier/README.adoc[XMPP]
|
|link:consumer/spring-xmpp-consumer/README.adoc[XMPP]
|===
== Guidelines
See link:CONTRIBUTING.adoc[Contributor Guidelines].
See link:CONTRIBUTING.adoc[Contributor Guidelines].


@@ -8,7 +8,6 @@ plugins {
id 'io.spring.dependency-management' version '1.1.4'
id 'io.spring.javaformat' version "${javaFormatVersion}"
id 'com.github.spotbugs' version '6.0.6'
id 'com.google.protobuf' version '0.9.4' apply false
id 'org.ajoberstar.grgit' version '5.2.1'
}
@@ -149,10 +148,6 @@ configure(javaProjects) { subproject ->
checkstyle("io.spring.javaformat:spring-javaformat-checkstyle:${javaFormatVersion}")
}
tasks.named('checkFormatMain') {
exclude 'org/springframework/cloud/fn/object/detection/protos'
}
[compileJava, compileTestJava]*.options*.compilerArgs = ['-Xlint:all,-options,-processing', '-parameters']
test {


@@ -1,344 +0,0 @@
:images-asciidoc: https://raw.githubusercontent.com/spring-cloud/stream-applications/master/functions/common/tensorflow-common/src/main/resources/images/
= Programming Model for TensorFlow Inference
The programming model builds on the https://docs.oracle.com/javase/8/docs/api/java/util/function/package-summary.html[Java Function API], the `TF Java Ops API` and a few basic data structures that together help to unify and streamline the building of TensorFlow inference pipelines.
Quick Start: just add the following dependency:
[source,XML]
----
<dependency>
<groupId>org.springframework.cloud.fn</groupId>
<artifactId>tensorflow-common</artifactId>
<version>${revision}</version>
</dependency>
----
== Programming Model
Implementing real-time TensorFlow inference in Java typically leverages the TF Java API for loading and scoring the pre-trained models. But the logic used to convert the upstream data into model input Tensors (e.g. pre-processor) and in turn to convert the inferred Tensors back into application data (e.g. post-processor) is commonly implemented in plain Java:
image::{images-asciidoc}/programming_model.png[TF Architecture, scaledwidth="70%"]
The pre- and post-processing steps can become very complex (check https://github.com/ildoonet/tf-pose-estimation[Pose Estimation] or https://github.com/davidsandberg/facenet[Face Recognition]) and computationally intensive, for example involving many math and image operations that are a better fit for optimized TF utilities than for plain Java math or AWT/2D/Canvas code.
Additionally, the unnecessary shuffling of data between the JVM and the native TF impacts the overall performance of the flow. The issue is apparent when multiple pre-trained TF models are combined (composed) or the same TF models are evaluated iteratively. In those cases the processed data is repeatedly being moved between the JVM Heap and the TF native memory.
The `Java Ops API` exposes the native https://www.tensorflow.org/versions/r1.9/api_docs/cc?hl=en[TF C++ Core API], offering comprehensive native math, image, io and similar utilities. The latter are useful for implementing computationally intensive logic with the same TF tools and infrastructure used for running the pre-trained models.
The proposed programming model builds on the https://docs.oracle.com/javase/8/docs/api/java/util/function/package-summary.html[Java Function API], the Java Ops API and a basic data structure that together help to unify and streamline the building of TensorFlow inference pipelines. While focused on model inference, this programming model is likely to be useful for building model-training pipelines as well.
The programming model leverages the following functional definition:
[source,Java]
----
Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>>
----
and the corresponding method expression
[source,Java]
----
Map<String, Tensor<?>> apply( Map<String, Tensor<?>> feeds)
----
This function receives a map of named Tensors as an input and in turn returns a map of named Tensors. Because the input and the output formats are equivalent, functions with this signature can be reused and composed into larger, complex functions.
The names used in the input and output maps are strings of the form `operation_name:output_index` (output_index defaults to 0). The names must match the indexed operations inside the underlying TF graph. +
The https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Tensor[Tensor] class is a reference to the data used natively in the TF engine. The referenced data is not moved to JVM Heap unless explicitly materialized with `Tensor.copyTo()`. Exchanging Tensor references between functions prevents unnecessarily copying the data to the JVM heap. +
The proposed data structure fits well with the existing https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner[Session.Runner API], which accepts https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner.html#feed(java.lang.String,%20org.tensorflow.Tensor%3C?%3E)[indexed operations] as an input feed and returns a list of tensors predefined by https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner.html#fetch(java.lang.String)[fetch indexed operations].
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java[GraphRunner] and the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphDefinition.java[GraphDefinition] are the core abstractions used to define, load and run inference on TensorFlow models. https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java[GraphRunner] implements the Function (e.g. Fn<Map<S,T>, Map<S,T>>) definition and uses the TF Java API to run the underlying TF graph. The input Tensor map is fed to the Session Runner. After the graph is evaluated, a list of predefined fetch names is used to retrieve selected Tensors from the result as a named Tensor map. The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java#L70[withGraphDefinition(GraphDefinition)] method defines a new TF graph or loads a pre-trained one, while the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java#L84[withSavedModel(path)] method helps to load a TensorFlow SavedModel.
The GraphDefinition argument is a functional interface and can therefore be used as the assignment target for a lambda expression or method reference.
The following snippet illustrates how to use `withGraphDefinition` to define a new TF Graph that computes the `y1 = x1 * 2` expression:
[source,Java]
----
myGraph = new GraphRunner("x1", "y1")
.withGraphDefinition( tf ->
tf.withName("y1").math.mul(
tf.withName("x1").placeholder(Integer.class),
tf.constant(2)));
----
The x1 and y1 constructor arguments define the input and output Tensor names (technically indexed operation names) used to feed in and fetch out data to and from the defined model.
The GraphRunner https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/AbstractGraphRunner.java#L65[apply] method evaluates the graph defined above:
[source,Java]
----
Map<String, Tensor<?>> input = Collections.singletonMap("x1", Tensor.create(666));
result = myGraph.apply(input);
----
Similarly, we can load a frozen/pre-trained model (the https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet#pretrained-models[MobileNetV2] model in this case) from an archive using the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/ProtoBufGraphDefinition.java[ProtoBufGraphDefinition] helper class.
[source,Java]
----
mobileNetV2 = new GraphRunner("input", "MobilenetV2/Predictions/Reshape_1")
.withGraphDefinition(
new ProtoBufGraphDefinition(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz#mobilenet_v2_1.4_224_frozen.pb",
cacheModel));
----
You can load archives from `http://`, `file://` or `classpath://` locations.
For loading a `SavedModel` use the `GraphRunner#withSavedModel` method like this:
[source,Java]
----
ssdMobileNetV1Coco = new GraphRunner( Arrays.asList("image_tensor"),
Arrays.asList("detection_boxes", "detection_scores", "detection_classes"))
.withSavedModel( "./ssd_mobilenet_v1_coco_2017_11_17/saved_model", "serve");
----
An example inference pipeline based on the proposed Functional Programming Model would look similar to this:
image::{images-asciidoc}/tf_pipeline.png[TF pipeline, scaledwidth="70%"]
Every GraphRunner instance in the pipeline uses either an in-place defined, or a pre-trained TF graph.
In practice, you still need to implement some input and output adapters for logic that cannot (or cannot feasibly) be implemented with the native Java Ops API. But you are free to choose which parts of the processing logic run as native code (e.g. via the Java Ops API) and which run as plain Java.
Furthermore, we are not limited to GraphRunner but any custom https://docs.oracle.com/javase/8/docs/api/java/util/function/Function.html[Function]<Map<String, Tensor>, Map<String, Tensor>> implementations can be used in the processing pipelines. In fact the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/Functions.java[Functions] utilities use this approach.
When appropriate any custom Function, https://docs.oracle.com/javase/8/docs/api/java/util/function/Supplier.html[Supplier]<Map<String, Tensor>>, https://docs.oracle.com/javase/8/docs/api/java/util/function/Consumer.html[Consumer]<Map<String, Tensor>> or the rest of the https://docs.oracle.com/javase/8/docs/api/java/util/function/package-frame.html[java.util.function] classes and interfaces can be used.
For real time examples check the https://github.com/spring-cloud/stream-applications/tree/master/functions/function/image-recognition-function[image-recognition] and https://github.com/spring-cloud/stream-applications/tree/master/functions/function/semantic-segmentation-function[semantic-segmentation] implementations.
== Features
The following paragraphs discuss some features and techniques useful for composing graphs, memorizing and reusing intermediate Tensor values, managing the Tensor resources and so on.
=== Input and Output Contracts
The GraphRunner constructor expects two compulsory list fields: `feedNames` - list of names that the graph accepts as input Tensors and `fetchNames` - list of (Tensor) names that the graph would return.
[source,Java]
----
public GraphRunner(List<String> feedNames, List<String> fetchedNames)
----
Together those two lists define the input and output contract of the graph runner. +
The names used in the inputs, and the outputs maps are strings of the form `operation_name : output_index` (output_index defaults to 0). The names must match the indexed operations inside the underlying TF graph.
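The naming convention above can be sketched with a small parser. This is only an illustration of the `operation_name:output_index` form and its default index of 0; the `IndexedOperationName` class and its `parse` method are hypothetical names and are not part of tensorflow-common.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.Map;

// Hypothetical helper, not part of tensorflow-common: splits a feed/fetch name
// of the form "operation_name:output_index" into its two parts.
public final class IndexedOperationName {

    private IndexedOperationName() {
    }

    public static Map.Entry<String, Integer> parse(String name) {
        String trimmed = name.trim();
        int colon = trimmed.lastIndexOf(':');
        if (colon < 0) {
            // No explicit index: output_index defaults to 0.
            return new SimpleImmutableEntry<>(trimmed, 0);
        }
        String operation = trimmed.substring(0, colon).trim();
        int index = Integer.parseInt(trimmed.substring(colon + 1).trim());
        return new SimpleImmutableEntry<>(operation, index);
    }
}
```

Under these assumptions, `parse("detection_boxes")` and `parse("detection_boxes:0")` refer to the same output of the same operation.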
=== Composition
Because the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java[GraphRunner] function signature uses the same type for input and output parameters, the https://docs.oracle.com/javase/8/docs/api/java/util/function/Function.html[Function] interface allows us to compose multiple https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunner.java[GraphRunner] functions into a larger composite function:
[source,Java]
----
composed-graph = graph1.andThen(graph2)....andThen(graphN)
----
For example let's take two simple graphs: `G1 (y1 = x1 * 2)` and `G2 (y2 = x2 + 20)`. The composed graph `G = G1.andThen(G2)` is equivalent to `y = (x * 2 ) + 20`.
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/test/java/org/springframework/cloud/fn/common/tensorflow/FunctionComposition.java[FunctionComposition example] demonstrates how this works:
[source,Java]
----
try (
GraphRunner graph1 = new GraphRunner("x1", "y1")
.withGraphDefinition(tf -> tf.withName("y1").math.mul(
tf.withName("x1").placeholder(Integer.class),
tf.constant(2)));
GraphRunner graph2 = new GraphRunner("x2", "y2")
.withGraphDefinition(tf -> tf.withName("y2").math.add(
tf.withName("x2").placeholder(Integer.class),
tf.constant(20)));
Tensor x = Tensor.create(10);
) {
Map<String, Tensor<?>> result =
graph1.andThen(graph2).apply(Collections.singletonMap("x", x));
System.out.println("Result is: " + result.get("y2").intValue()); // Result is: 40
}
----
Note that the GraphRunner https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/AbstractGraphRunner.java#L65[automatically binds] a singleton output (e.g. fetch) to a singleton input (e.g. feed). In the example above the GraphRunner automatically binds the `y1` tensor produced by `graph1` to the `x2` input placeholder expected by `graph2`. Likewise, the initial `x` entry is bound to the `x1` placeholder of `graph1`.
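The auto-binding behavior can be modeled without TensorFlow. The sketch below is a simplification under stated assumptions: `Integer` stands in for `Tensor`, and `AutoBindingSketch` with its `stage` and `run` methods are hypothetical names. It mirrors only the rule that a single incoming entry is fed to the single expected input regardless of its key.

```java
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;

// Hypothetical model of single-feed auto-binding; Integer stands in for Tensor.
public final class AutoBindingSketch {

    private AutoBindingSketch() {
    }

    // Builds a single-input, single-output stage that ignores the incoming key.
    public static Function<Map<String, Integer>, Map<String, Integer>> stage(
            String feedName, String fetchName, IntUnaryOperator body) {
        return feeds -> {
            if (feeds.size() != 1) {
                throw new IllegalArgumentException(
                        "Auto-binding to '" + feedName + "' expects a single feed but found: " + feeds);
            }
            // The only entry is bound to feedName, whatever name it arrived under.
            int value = feeds.values().iterator().next();
            return Collections.singletonMap(fetchName, body.applyAsInt(value));
        };
    }

    // G1: y1 = x1 * 2, G2: y2 = x2 + 20, composed as G1.andThen(G2).
    public static int run(int x) {
        Function<Map<String, Integer>, Map<String, Integer>> g1 = stage("x1", "y1", v -> v * 2);
        Function<Map<String, Integer>, Map<String, Integer>> g2 = stage("x2", "y2", v -> v + 20);
        return g1.andThen(g2).apply(Collections.singletonMap("x", x)).get("y2");
    }
}
```

With an input of 10, `run` evaluates `(10 * 2) + 20`, matching the composed graph described above.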
==== Multiple inputs/outputs
When the composed graphs use multiple input and output parameters we need to explicitly bind the outputs from the upstream graph to the inputs of the downstream one.
For example, let Graph1 produce two outputs (e.g. fetchNames) y11 and y12, and let Graph2 expect two inputs (e.g. feedNames) x21 and x22:
|===
|Graph1:|Graph2:
| y11 = x1 * 2 | y2 = x21 + x22
| y12 = x1 * 3 |
|===
The composed graph would look like this:
[source,Java]
----
Composed = Graph1.andThen( map: y11 -> x21 and y12 -> x22).andThen(Graph2)
----
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/Functions.java#L72[Functions#rename] utility helps to define the input/output mappings as illustrated in the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/test/java/org/springframework/cloud/fn/common/tensorflow/FunctionCompositionMultipleInputsOutputs.java[FunctionCompositionMultipleInputsOutputs] example:
[source,Java]
----
try (
GraphRunner graph1 = new GraphRunner(Arrays.asList("x1"), Arrays.asList("y11", "y12"))
.withGraphDefinition(tf -> {
Placeholder<Integer> x1 = tf.withName("x1").placeholder(Integer.class);
tf.withName("y11").math.mul(x1, tf.constant(2));
tf.withName("y12").math.mul(x1, tf.constant(3));
});
GraphRunner graph2 = new GraphRunner(Arrays.asList("x21", "x22"), Arrays.asList("y2"))
.withGraphDefinition(tf -> tf.withName("y2").math.add(
tf.withName("x21").placeholder(Integer.class),
tf.withName("x22").placeholder(Integer.class)));
Tensor x = Tensor.create(10);
) {
Map<String, Tensor<?>> result =
graph1
.andThen(
Functions.rename(
"y11", "x21",
"y12", "x22"
))
.andThen(graph2)
.apply(Collections.singletonMap("x", x));
System.out.println("Result is: " + result.get("y2").intValue()); // Result is: 50
}
----
The Functions#rename(String...mappings) takes an even number of strings forming (from, to) pairs: each odd-positioned argument is the name to map from and the argument after it is the name to map to. E.g. above, y11 is mapped to x21 and y12 is mapped to x22. +
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/AbstractGraphRunner.java#L129[GraphRunner#enableAutoBinding()] and https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/AbstractGraphRunner.java#L124[GraphRunner#disableAutoBinding()] allow altering the autobinding behavior enforcing mapping even of singleton input/output graphs.
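To make the pairwise renaming concrete, here is a minimal sketch of how such a function could be written. It is not the actual `Functions#rename` implementation: the `RenameSketch` class is hypothetical, it is generic over the value type instead of using `Tensor`, and passing unmapped names through unchanged is an assumption.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of a rename function built from (from, to) name pairs.
public final class RenameSketch {

    private RenameSketch() {
    }

    public static <T> Function<Map<String, T>, Map<String, T>> rename(String... mappings) {
        if (mappings.length % 2 != 0) {
            throw new IllegalArgumentException(
                    "Expected (from, to) pairs but got " + mappings.length + " names");
        }
        Map<String, String> fromTo = new HashMap<>();
        for (int i = 0; i < mappings.length; i += 2) {
            fromTo.put(mappings[i], mappings[i + 1]);
        }
        return input -> {
            // Build a new map; names without a mapping are kept as they are (assumption).
            Map<String, T> renamed = new HashMap<>();
            input.forEach((name, value) -> renamed.put(fromTo.getOrDefault(name, name), value));
            return renamed;
        };
    }
}
```

Returning a new map instead of mutating the input keeps the function safe to use with immutable maps such as `Collections.singletonMap`.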
=== Save and Close Obsolete Tensors
The Tensors used as inputs (feeds) and outputs (fetches) by the GraphRunners have to be released (e.g. closed) when not used anymore.
Because every sub-graph in a composite pipeline produces one or more <String, Tensor> pairs we need to track those references and close them.
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunnerMemory.java[GraphRunnerMemory] is a handy utility Function implementation that keeps track of all input Tensor parameters passed through. It is https://docs.oracle.com/javase/8/docs/api/java/lang/AutoCloseable.html[AutoCloseable] and will release all tracked Tensors when closed.
The GraphRunnerMemory implements the same function signatures as the GraphRunner (e.g. Fun<Map<S,T>, Map<S,T>>) and therefore can participate in composite graph definitions:
[source,Java]
----
try ( memory = new GraphRunnerMemory() ) {
composed-graph =
Graph1.andThen(memory)
.andThen(Graph2).andThen(memory)
.andThen(GraphN).andThen(memory)
….
} // releases all Tensors returned by the GraphRunners
----
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/test/java/org/springframework/cloud/fn/common/tensorflow/ReleaseTensorParameters.java[ReleaseTensorParameters] example illustrates how to use the GraphRunnerMemory:
[source,Java]
----
try (
Tensor x = Tensor.create(input);
GraphRunnerMemory memory = new GraphRunnerMemory();
) {
Map<String, Tensor<?>> result =
this.graph1.andThen(memory)
.andThen(this.graph2).andThen(memory)
.apply(Collections.singletonMap("x", x));
return result.get("y2").intValue();
}
// At that point all intermediate Tensors used by the GraphRunners are closed.
----
Note: the GraphRunnerMemory has some other very useful applications that we will highlight in the next paragraph.
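The tracking idea can be illustrated with a pass-through function that remembers every value it sees and closes them all later. This is a sketch under stated assumptions, not the real `GraphRunnerMemory`: the `MemorySketch` class and its nested `Handle` type (a stand-in for a native resource such as a Tensor) are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical pass-through function that tracks resources and releases them on close().
public final class MemorySketch<T extends AutoCloseable>
        implements Function<Map<String, T>, Map<String, T>>, AutoCloseable {

    // Stand-in for a native resource such as a Tensor.
    public static final class Handle implements AutoCloseable {
        private boolean closed;

        @Override
        public void close() {
            this.closed = true;
        }

        public boolean isClosed() {
            return this.closed;
        }
    }

    private final List<T> tracked = new ArrayList<>();

    @Override
    public Map<String, T> apply(Map<String, T> values) {
        // Remember every value passing through and hand the map on unchanged.
        this.tracked.addAll(values.values());
        return values;
    }

    @Override
    public void close() {
        for (T resource : this.tracked) {
            try {
                resource.close();
            }
            catch (Exception ex) {
                // Simplification: stop at the first failure.
                throw new IllegalStateException("Failed to release a tracked resource", ex);
            }
        }
        this.tracked.clear();
    }
}
```

Because `apply` returns its input unchanged, such a function can be inserted after any stage with `andThen` without affecting the values that flow downstream.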
=== Enrich Graph Inputs
For particular graphs in the composite pipeline, we can add additional input parameters that were not produced by the upstream graph.
With the help of the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/Functions.java#L42[Functions#enrichWith(name, Tensor)] utility function we can inject the additional parameters into the graph composition.
In the following snippet we enrich graph2's input with an additional parameter (newParam):
[source,Java]
----
try (
Tensor x = Tensor.create(input);
Tensor additionalTensor = Tensor.create(colorMap);
) {
Map<String, Tensor<?>> result =
graph1
.andThen(Functions.enrichWith("newParam", additionalTensor))
.andThen(graph2)
.apply(Collections.singletonMap("x", x));
return result.get("y2").intValue();
}
----
The https://github.com/spring-cloud/stream-applications/blob/master/functions/function/semantic-segmentation-function/src/main/java/org/springframework/cloud/fn/semantic/segmentation/SemanticSegmentation.java#L150[SemanticSegmentation] implementation provides a real example how to enrich with parameters.
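The enrichment step itself is small. The following sketch shows one way to write it, assuming the function returns an enriched copy of its input rather than mutating it. `EnrichSketch` is a hypothetical class, generic over the value type instead of using `Tensor`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch: adds one fixed (name, value) pair to every map passing through.
public final class EnrichSketch {

    private EnrichSketch() {
    }

    public static <T> Function<Map<String, T>, Map<String, T>> enrichWith(String name, T value) {
        return input -> {
            // Copy first so that immutable or shared input maps are left untouched.
            Map<String, T> enriched = new HashMap<>(input);
            enriched.put(name, value);
            return enriched;
        };
    }
}
```

Placed between two stages with `andThen`, the downstream stage then sees both the upstream outputs and the extra entry.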
=== Enrich Inputs from Saved Tensors
We can combine the enricher approach with the https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/GraphRunnerMemory.java[GraphRunnerMemory]. This allows us to enrich some downstream Graphs with tensor parameters computed in some of the upstream Graphs. The `Functions#enrichFromMemory(memory, tensorName)` utility function can enrich a graph input parameter by extracting one stored in the memory.
For example, let's construct the following graph composition:
----
graph1: y1 = x1 * 10
graph2: y2 = y1 * 200
graph3: y3 = y2 + y1
----
[source,Java]
----
try (
Tensor x = Tensor.create(input);
GraphRunnerMemory memory = new GraphRunnerMemory();
) {
Map<String, Tensor<?>> result =
this.graph1.andThen(memory) // memorizes y1
.andThen(graph2).andThen(memory) // memorizes y2
.andThen(Functions.enrichFromMemory(memory, "y1")) // retrieves graph1's output y1 and adds it as an input for the next function.
.andThen(Functions.rename(
"y1", "x31", // renames the input y1 into x31
"y2", "x32" // renames the input y2 into x32
))
.andThen(graph3).andThen(memory)
.apply(Collections.singletonMap("x", x));
return result.get("y3").intValue();
}
----
=== Load Frozen Models from Remote Archives
The ProtoBufGraphDefinition extracts a pre-trained (frozen) TensorFlow model from a URI archive into a byte array. It supports the `http(s)://`, `file://` and `classpath://` URI schemes. For this it uses the `ModelExtractor` and `CachedModelExtractor` utilities.
Models can be extracted either from raw files or from compressed archives. When extracted from an archive, the model file name can optionally be provided as a URI fragment. For example, for the resource `http://myarchive.tar.gz#model.pb`
the `myarchive.tar.gz` is traversed to uncompress and extract the model.pb file as a byte array. If the file name is not provided as a URI fragment, then the first file in the archive with the extension .pb is extracted.
In addition, the CachedModelExtractor allows keeping a local copy (cache) of the model (protobuf) files extracted from the URI archive.
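The URI-fragment convention described above can be illustrated with `java.net.URI` alone. The sketch shows only how the archive location and the optional model file name could be separated. `ModelUriSketch` and its methods are hypothetical names and do not reflect how `ModelExtractor` is actually implemented.

```java
import java.net.URI;
import java.util.Optional;

// Hypothetical helper: separates an archive location from an optional "#model-file" fragment.
public final class ModelUriSketch {

    private ModelUriSketch() {
    }

    // Returns the URI text without the fragment part.
    public static String archiveLocation(String uri) {
        String text = URI.create(uri).toString();
        int hash = text.indexOf('#');
        return (hash < 0) ? text : text.substring(0, hash);
    }

    // Returns the model file name if one was given as a URI fragment.
    public static Optional<String> modelFileName(String uri) {
        return Optional.ofNullable(URI.create(uri).getFragment())
                .filter(fragment -> !fragment.isEmpty());
    }
}
```

When `modelFileName` is empty, a caller would fall back to the rule stated above and pick the first `.pb` file found in the archive.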
|===
|The https://github.com/spring-cloud/stream-applications/tree/master/functions/function/image-recognition-function[image-recognition] and https://github.com/spring-cloud/stream-applications/tree/master/functions/function/semantic-segmentation-function[semantic-segmentation] inference models implementations demonstrate the suggested programming model.
|===


@@ -1,12 +0,0 @@
ext {
tensorflowVersion='1.15.0'
}
dependencies {
api "org.tensorflow:tensorflow:$tensorflowVersion"
api "org.tensorflow:proto:$tensorflowVersion"
api 'org.apache.commons:commons-compress:1.25.0'
api 'commons-io:commons-io:2.15.0'
api 'org.apache.commons:commons-lang3'
api 'org.pcollections:pcollections:4.0.1'
}


@@ -1,144 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.lang3.Validate;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/**
* @author Christian Tzolov
*/
public abstract class AbstractGraphRunner implements Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>> {
public abstract Session doGetSession();
/**
* Names expected in the named Tensor map passed as input to
* {@link AbstractGraphRunner#apply(Map)}. The apply method will fail if the input
* map is missing some of the feedNames.
*/
private final List<String> feedNames;
/**
* Names expected {@link AbstractGraphRunner#apply(Map)} result map.
*/
private final List<String> fetchNames;
/**
* When set and the input takes a single feed, then the name of the input tensor is
* automatically mapped to the expected input name. E.g. no need to rename the input
* names explicitly.
*/
private boolean autoBinding;
public AbstractGraphRunner(String feedName, String fetchedName) {
this(Arrays.asList(feedName), Arrays.asList(fetchedName));
}
public AbstractGraphRunner(List<String> feedNames, List<String> fetchedNames) {
this.feedNames = feedNames;
this.fetchNames = fetchedNames;
this.autoBinding = feedNames.size() == 1;
}
@Override
public Map<String, Tensor<?>> apply(Map<String, Tensor<?>> feeds) {
if (!this.isAutoBinding() && !feeds.keySet().containsAll(this.feedNames)) {
throw new IllegalArgumentException("Applied feeds:" + feeds.keySet()
+ "\n, don't match the expected feeds contract:" + this.feedNames);
}
if (this.isAutoBinding() && (feeds.size() != 1)) {
throw new IllegalArgumentException(
"Feed auto-binding expects a " + "single feed tensors but found: " + feeds);
}
Session.Runner runner = this.doGetSession().runner();
// Feed in the input named tensors
for (Map.Entry<String, Tensor<?>> feedEntry : feeds.entrySet()) {
String feedName = (this.isAutoBinding()) ? this.feedNames.get(0) : feedEntry.getKey();
runner = runner.feed(feedName, feedEntry.getValue());
}
// Set the tensor name to be fetched after the evaluation
for (String fetchName : this.fetchNames) {
runner.fetch(fetchName);
}
// Evaluate the input
List<Tensor<?>> outputTensors = runner.run();
// Extract the output tensors
Map<String, Tensor<?>> outTensorMap = new HashMap<>();
for (int outputIndex = 0; outputIndex < this.fetchNames.size(); outputIndex++) {
outTensorMap.put(this.fetchNames.get(outputIndex), outputTensors.get(outputIndex));
}
return outTensorMap;
}
public List<String> getFeedNames() {
return this.feedNames;
}
public String getSingleFeedName() {
Validate.isTrue(feedNames.size() == 1, "Assumes a single feed input");
return this.feedNames.get(0);
}
public List<String> getFetchNames() {
return this.fetchNames;
}
public String getSingleFetchName() {
Validate.isTrue(this.fetchNames.size() == 1, "Assumes a single fetch output");
return this.fetchNames.get(0);
}
public boolean isAutoBinding() {
return this.autoBinding;
}
public AbstractGraphRunner disableAutoBinding() {
this.autoBinding = false;
return this;
}
public AbstractGraphRunner enableAutoBinding() {
if (this.getFeedNames().size() != 1) {
throw new IllegalArgumentException("Auto-binding is permitted for Graphs with single input feed, but "
+ " found: " + this.getFeedNames());
}
this.autoBinding = true;
return this;
}
@Override
public String toString() {
return String.format("(%s) -> (%s)", String.join(",", this.feedNames), String.join(",", this.fetchNames));
}
}


@@ -1,64 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.op.Ops;
/**
* @author Christian Tzolov
*/
class AutoCloseableSession implements AutoCloseable {
private Session session;
/**
* Note: don't call this method inside the constructor.
*/
protected void init() {
Graph graph = this.doCreateGraph();
this.doGraphDefinition(Ops.create(graph));
this.session = new Session(graph);
}
protected Graph doCreateGraph() {
return new Graph();
}
protected void doGraphDefinition(Ops tf) {
}
protected Session getSession() {
if (this.session == null) {
init();
}
return this.session;
}
@Override
public void close() {
this.doClose();
if (this.session != null) {
this.session.close();
}
}
protected void doClose() {
}
}

@@ -1,87 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.tensorflow.Tensor;
/**
* @author Christian Tzolov
*/
public final class Functions {
private Functions() {
}
/**
* On every function call, enriches the input tensorMap with an additional (tensorName,
* tensor) pair.
* @param tensorName tensor key to use in the map
* @param tensor new Tensor to add to the map
* @return Returns a copy of the input tensorMap enriched with the provided
* (tensorName, tensor).
*/
public static Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>> enrichWith(String tensorName,
Tensor<?> tensor) {
return tensorMap -> enrich(tensorMap, tensorName, tensor);
}
/**
* On function call retrieves a named tensor from the provided
* {@link GraphRunnerMemory} and uses it to enrich the input tensorMap.
* @param memory GraphRunnerMemory to retrieve the tensor from
* @param tensorName name of the tensor in GraphRunnerMemory to retrieve.
* @return Returns copy of the input tensorMap enriched with the tensor from the
* memory.
*/
public static Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>> enrichFromMemory(GraphRunnerMemory memory,
String tensorName) {
return tensorMap -> enrich(tensorMap, tensorName, memory.getTensorMap().get(tensorName));
}
private static Map<String, Tensor<?>> enrich(Map<String, Tensor<?>> inputTensorMap, String key, Tensor<?> value) {
Map<String, Tensor<?>> newMap = new HashMap<>(inputTensorMap);
newMap.put(key, value);
return newMap;
}
/**
* Renames the tensor names in the incoming tensorMap using the provided mappings.
* @param mapping Pairs of From and To names. E.g. fromName1, toName1, fromName2,
* toName2, ... fromNameN, toNameN. Must contain an even number of elements.
* @return Function that renames the input tensorMap entries according to the mapping
* provided. Entries without a mapping are dropped.
*/
public static Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>> rename(String... mapping) {
Map<String, String> mappingMap = new HashMap<>();
for (int i = 0; i < mapping.length; i = i + 2) {
mappingMap.put(mapping[i], mapping[i + 1]);
}
return tensorMap -> tensorMap.entrySet()
.stream()
.filter(e -> mappingMap.containsKey(e.getKey()))
.collect(Collectors.toMap(kv -> mappingMap.get(kv.getKey()), kv -> kv.getValue()));
}
}

@@ -1,29 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import org.tensorflow.op.Ops;
/**
* @author Christian Tzolov
*/
@FunctionalInterface
public interface GraphDefinition {
void defineGraph(Ops tf);
}

@@ -1,107 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.Validate;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.op.Ops;
/**
* @author Christian Tzolov
*/
public class GraphRunner extends AbstractGraphRunner implements AutoCloseable {
private SavedModelBundle savedModelBundle;
private AutoCloseableSession autoCloseableSession;
public GraphRunner(List<String> feedNames, String fetchedName) {
super(feedNames, Arrays.asList(fetchedName));
}
public GraphRunner(String feedName, List<String> fetchedNames) {
super(Arrays.asList(feedName), fetchedNames);
}
public GraphRunner(String feedName, String fetchedName) {
super(feedName, fetchedName);
}
public GraphRunner(List<String> feedNames, List<String> fetchedNames) {
super(feedNames, fetchedNames);
}
@Override
public Session doGetSession() {
if (this.autoCloseableSession != null && this.savedModelBundle != null) {
throw new IllegalStateException("Either SavedModel or GraphDefinition can be set! But both are set!");
}
if (this.autoCloseableSession != null) {
return this.autoCloseableSession.getSession();
}
if (this.savedModelBundle != null) {
return this.savedModelBundle.session();
}
throw new IllegalStateException("Either SavedModel or GraphDefinition can be set! None found");
}
public GraphRunner withGraphDefinition(GraphDefinition graphDefinition) {
Validate.isTrue(this.savedModelBundle == null, "Either SavedModel or GraphDefinition can be set! "
+ "SavedModelBundle is found: " + this.savedModelBundle);
this.autoCloseableSession = new AutoCloseableSession() {
@Override
protected void doGraphDefinition(Ops tf) {
graphDefinition.defineGraph(tf);
}
};
return this;
}
public GraphRunner withSavedModel(String savedModelDir, String... tags) {
Validate.isTrue(this.autoCloseableSession == null, "Either SavedModel or GraphDefinition can be set! "
+ "AutoCloseableSession is found: " + this.autoCloseableSession);
this.savedModelBundle = SavedModelBundle.load(savedModelDir, tags);
return this;
}
@Override
public String toString() {
return String.format("(%s) -> (%s)", String.join(",", this.getFeedNames()),
String.join(",", this.getFetchNames()));
}
@Override
public void close() {
if (this.savedModelBundle != null) {
this.savedModelBundle.close();
}
if (this.autoCloseableSession != null) {
this.autoCloseableSession.close();
}
}
}

@@ -1,52 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.pcollections.HashTreePMap;
import org.pcollections.PMap;
import org.tensorflow.Tensor;
import org.springframework.cloud.fn.common.tensorflow.util.AutoCloseables;
/**
* Keeps all tensorMap input parameters.
*/
public class GraphRunnerMemory implements Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>>, AutoCloseable {
private AtomicReference<PMap<String, Tensor<?>>> tensorMap = new AtomicReference<>(HashTreePMap.empty());
public Map<String, Tensor<?>> getTensorMap() {
return tensorMap.get();
}
@Override
public Map<String, Tensor<?>> apply(Map<String, Tensor<?>> tensorMap) {
this.tensorMap.getAndUpdate(pmap -> pmap.plusAll(tensorMap));
return tensorMap;
}
@Override
public void close() {
AutoCloseables.all(this.tensorMap.get());
// this.tensorMap.get().clear();
}
}

@@ -1,75 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Iterator;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.op.Ops;
import org.springframework.cloud.fn.common.tensorflow.util.CachedModelExtractor;
import org.springframework.cloud.fn.common.tensorflow.util.ModelExtractor;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
/**
* @author Christian Tzolov
*/
public class ProtoBufGraphDefinition implements GraphDefinition {
/**
* Location of the pre-trained model archive.
*/
private final Resource modelLocation;
/**
* If set to true, the pre-trained model is cached on the local file system.
*/
private final boolean cacheModel;
public ProtoBufGraphDefinition(String modelUri, boolean cacheModel) {
this(new DefaultResourceLoader().getResource(modelUri), cacheModel);
}
public ProtoBufGraphDefinition(Resource modelLocation, boolean cacheModel) {
this.modelLocation = modelLocation;
this.cacheModel = cacheModel;
}
@Override
public void defineGraph(Ops tf) {
// Extract the pre-trained model as byte array.
byte[] model = this.cacheModel ? new CachedModelExtractor().getModel(this.modelLocation)
: new ModelExtractor().getModel(this.modelLocation);
// Import the pre-trained model
((Graph) tf.scope().env()).importGraphDef(model);
// try {
// ((Graph) tf.scope().env()).importGraphDef(GraphDef.parseFrom(model));
// }
// catch (InvalidProtocolBufferException e) {
// throw new RuntimeException(e);
// }
Graph graph = ((Graph) tf.scope().env());
Iterator<Operation> ops = graph.operations();
while (ops.hasNext()) {
System.out.println(ops.next().name());
}
}
}

@@ -1,762 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.deprecated;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.RenderingHints;
import java.awt.Stroke;
import java.awt.geom.Rectangle2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import javax.imageio.ImageIO;
import org.apache.commons.io.IOUtils;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
/**
* Utility class that provides some handy image manipulation functions. Among others, it
* can provide contrast colors for image annotation labels and bounding boxes, as well as
* functionality to draw the latter.
*
* @author Christian Tzolov
*/
public final class GraphicsUtils {
private GraphicsUtils() {
}
/**
* Default font used in image label annotation.
*/
private static final Font DEFAULT_FONT = new Font("arial", Font.PLAIN, 12);
/**
* Bounding box default line thickness.
*/
private static final float LINE_THICKNESS = 2;
/**
* Color used when no multi-color is used.
*/
private static final Color AGNOSTIC_COLOR = new Color(167, 252, 0);
/**
* Text offset inside labels.
*/
public static final int TITLE_OFFSET = 3;
/**
* Predefined contrasting colors used when drawing multiple objects in the same image.
*/
public static final Color aliceblue = new Color(240, 248, 255); /* color */
/** Color. **/
public static final Color antiquewhite = new Color(250, 235, 215);
/** Color. **/
public static final Color aqua = new Color(0, 255, 255); // color
/** Color. **/
public static final Color aquamarine = new Color(127, 255, 212); // color
/** Color. **/
public static final Color azure = new Color(240, 255, 255); // color
/** Color. **/
public static final Color beige = new Color(245, 245, 220); // color
/** Color. **/
public static final Color bisque = new Color(255, 228, 196);
/** Color. **/
public static final Color black = new Color(0, 0, 0);
/** Color. **/
public static final Color blanchedalmond = new Color(255, 255, 205);
/** Color. **/
public static final Color blue = new Color(0, 0, 255);
/** Color. **/
public static final Color blueviolet = new Color(138, 43, 226);
/** Color. **/
public static final Color brown = new Color(165, 42, 42);
/** Color. **/
public static final Color burlywood = new Color(222, 184, 135);
/** Color. **/
public static final Color cadetblue = new Color(95, 158, 160);
/** Color. **/
public static final Color chartreuse = new Color(127, 255, 0);
/** Color. **/
public static final Color chocolate = new Color(210, 105, 30);
/** Color. **/
public static final Color coral = new Color(255, 127, 80);
/** Color. **/
public static final Color cornflowerblue = new Color(100, 149, 237);
/** Color. **/
public static final Color cornsilk = new Color(255, 248, 220);
/** Color. **/
public static final Color crimson = new Color(220, 20, 60);
/** Color. **/
public static final Color cyan = new Color(0, 255, 255);
/** Color. **/
public static final Color darkblue = new Color(0, 0, 139);
/** Color. **/
public static final Color darkcyan = new Color(0, 139, 139);
/** Color. **/
public static final Color darkgoldenrod = new Color(184, 134, 11);
/** Color. **/
public static final Color darkgray = new Color(169, 169, 169);
/** Color. **/
public static final Color darkgreen = new Color(0, 100, 0);
/** Color. **/
public static final Color darkkhaki = new Color(189, 183, 107);
/** Color. **/
public static final Color darkmagenta = new Color(139, 0, 139);
/** Color. **/
public static final Color darkolivegreen = new Color(85, 107, 47);
/** Color. **/
public static final Color darkorange = new Color(255, 140, 0);
/** Color. **/
public static final Color darkorchid = new Color(153, 50, 204);
/** Color. **/
public static final Color darkred = new Color(139, 0, 0);
/** Color. **/
public static final Color darksalmon = new Color(233, 150, 122);
/** Color. **/
public static final Color darkseagreen = new Color(143, 188, 143);
/** Color. **/
public static final Color darkslateblue = new Color(72, 61, 139);
/** Color. **/
public static final Color darkslategray = new Color(47, 79, 79);
/** Color. **/
public static final Color darkturquoise = new Color(0, 206, 209);
/** Color. **/
public static final Color darkviolet = new Color(148, 0, 211);
/** Color. **/
public static final Color deeppink = new Color(255, 20, 147);
/** Color. **/
public static final Color deepskyblue = new Color(0, 191, 255);
/** Color. **/
public static final Color dimgray = new Color(105, 105, 105);
/** Color. **/
public static final Color dodgerblue = new Color(30, 144, 255);
/** Color. **/
public static final Color firebrick = new Color(178, 34, 34);
/** Color. **/
public static final Color floralwhite = new Color(255, 250, 240);
/** Color. **/
public static final Color forestgreen = new Color(34, 139, 34);
/** Color. **/
public static final Color fuchsia = new Color(255, 0, 255);
/** Color. **/
public static final Color gainsboro = new Color(220, 220, 220);
/** Color. **/
public static final Color ghostwhite = new Color(248, 248, 255);
/** Color. **/
public static final Color gold = new Color(255, 215, 0);
/** Color. **/
public static final Color goldenrod = new Color(218, 165, 32);
/** Color. **/
public static final Color gray = new Color(128, 128, 128);
/** Color. **/
public static final Color green = new Color(0, 128, 0);
/** Color. **/
public static final Color greenyellow = new Color(173, 255, 47);
/** Color. **/
public static final Color honeydew = new Color(240, 255, 240);
/** Color. **/
public static final Color hotpink = new Color(255, 105, 180);
/** Color. **/
public static final Color indianred = new Color(205, 92, 92);
/** Color. **/
public static final Color indigo = new Color(75, 0, 130);
/** Color. **/
public static final Color ivory = new Color(255, 240, 240);
/** Color. **/
public static final Color khaki = new Color(240, 230, 140);
/** Color. **/
public static final Color lavender = new Color(230, 230, 250);
/** Color. **/
public static final Color lavenderblush = new Color(255, 240, 245);
/** Color. **/
public static final Color lawngreen = new Color(124, 252, 0);
/** Color. **/
public static final Color lemonchiffon = new Color(255, 250, 205);
/** Color. **/
public static final Color lightblue = new Color(173, 216, 230);
/** Color. **/
public static final Color lightcoral = new Color(240, 128, 128);
/** Color. **/
public static final Color lightcyan = new Color(224, 255, 255);
/** Color. **/
public static final Color lightgoldenrodyellow = new Color(250, 250, 210);
/** Color. **/
public static final Color lightgreen = new Color(144, 238, 144);
/** Color. **/
public static final Color lightgrey = new Color(211, 211, 211);
/** Color. **/
public static final Color lightpink = new Color(255, 182, 193);
/** Color. **/
public static final Color lightsalmon = new Color(255, 160, 122);
/** Color. **/
public static final Color lightseagreen = new Color(32, 178, 170);
/** Color. **/
public static final Color lightskyblue = new Color(135, 206, 250);
/** Color. **/
public static final Color lightslategray = new Color(119, 136, 153);
/** Color. **/
public static final Color lightsteelblue = new Color(176, 196, 222);
/** Color. **/
public static final Color lightyellow = new Color(255, 255, 224);
/** Color. **/
public static final Color lime = new Color(0, 255, 0);
/** Color. **/
public static final Color limegreen = new Color(50, 205, 50);
/** Color. **/
public static final Color linen = new Color(250, 240, 230);
/** Color. **/
public static final Color magenta = new Color(255, 0, 255);
/** Color. **/
public static final Color maroon = new Color(128, 0, 0);
/** Color. **/
public static final Color mediumaquamarine = new Color(102, 205, 170);
/** Color. **/
public static final Color mediumblue = new Color(0, 0, 205);
/** Color. **/
public static final Color mediumorchid = new Color(186, 85, 211);
/** Color. **/
public static final Color mediumpurple = new Color(147, 112, 219);
/** Color. **/
public static final Color mediumseagreen = new Color(60, 179, 113);
/** Color. **/
public static final Color mediumslateblue = new Color(123, 104, 238);
/** Color. **/
public static final Color mediumspringgreen = new Color(0, 250, 154);
/** Color. **/
public static final Color mediumturquoise = new Color(72, 209, 204);
/** Color. **/
public static final Color mediumvioletred = new Color(199, 21, 133);
/** Color. **/
public static final Color midnightblue = new Color(25, 25, 112);
/** Color. **/
public static final Color mintcream = new Color(245, 255, 250);
/** Color. **/
public static final Color mistyrose = new Color(255, 228, 225);
/** Color. **/
public static final Color mocassin = new Color(255, 228, 181);
/** Color. **/
public static final Color navajowhite = new Color(255, 222, 173);
/** Color. **/
public static final Color navy = new Color(0, 0, 128);
/** Color. **/
public static final Color oldlace = new Color(253, 245, 230);
/** Color. **/
public static final Color olive = new Color(128, 128, 0);
/** Color. **/
public static final Color olivedrab = new Color(107, 142, 35);
/** Color. **/
public static final Color orange = new Color(255, 165, 0);
/** Color. **/
public static final Color orangered = new Color(255, 69, 0);
/** Color. **/
public static final Color orchid = new Color(218, 112, 214);
/** Color. **/
public static final Color palegoldenrod = new Color(238, 232, 170);
/** Color. **/
public static final Color palegreen = new Color(152, 251, 152);
/** Color. **/
public static final Color paleturquoise = new Color(175, 238, 238);
/** Color. **/
public static final Color palevioletred = new Color(219, 112, 147);
/** Color. **/
public static final Color papayawhip = new Color(255, 239, 213);
/** Color. **/
public static final Color peachpuff = new Color(255, 218, 185);
/** Color. **/
public static final Color peru = new Color(205, 133, 63);
/** Color. **/
public static final Color pink = new Color(255, 192, 203);
/** Color. **/
public static final Color plum = new Color(221, 160, 221);
/** Color. **/
public static final Color powderblue = new Color(176, 224, 230);
/** Color. **/
public static final Color purple = new Color(128, 0, 128);
/** Color. **/
public static final Color red = new Color(255, 0, 0);
/** Color. **/
public static final Color rosybrown = new Color(188, 143, 143);
/** Color. **/
public static final Color royalblue = new Color(65, 105, 225);
/** Color. **/
public static final Color saddlebrown = new Color(139, 69, 19);
/** Color. **/
public static final Color salmon = new Color(250, 128, 114);
/** Color. **/
public static final Color sandybrown = new Color(244, 164, 96);
/** Color. **/
public static final Color seagreen = new Color(46, 139, 87);
/** Color. **/
public static final Color seashell = new Color(255, 245, 238);
/** Color. **/
public static final Color sienna = new Color(160, 82, 45);
/** Color. **/
public static final Color silver = new Color(192, 192, 192);
/** Color. **/
public static final Color skyblue = new Color(135, 206, 235);
/** Color. **/
public static final Color slateblue = new Color(106, 90, 205);
/** Color. **/
public static final Color slategray = new Color(112, 128, 144);
/** Color. **/
public static final Color snow = new Color(255, 250, 250);
/** Color. **/
public static final Color springgreen = new Color(0, 255, 127);
/** Color. **/
public static final Color steelblue = new Color(70, 138, 180);
/** Color. **/
public static final Color tan = new Color(210, 180, 140);
/** Color. **/
public static final Color teal = new Color(0, 128, 128);
/** Color. **/
public static final Color thistle = new Color(216, 191, 216);
/** Color. **/
public static final Color tomato = new Color(253, 99, 71);
/** Color. **/
public static final Color turquoise = new Color(64, 224, 208);
/** Color. **/
public static final Color violet = new Color(238, 130, 238);
/** Color. **/
public static final Color wheat = new Color(245, 222, 179);
/** Color. **/
public static final Color white = new Color(255, 255, 255);
/** Color. **/
public static final Color whitesmoke = new Color(245, 245, 245);
/** Color. **/
public static final Color yellow = new Color(255, 255, 0);
/** Color. **/
public static final Color yellowgreen = new Color(154, 205, 50);
/**
* Limbs color list.
*/
public static final Color[] LIMBS_COLORS = new Color[] { new Color(153, 0, 0), // 0 (1 -> 2)
new Color(153, 51, 0), // 1 (1 -> 5)
new Color(153, 102, 0), // 2 (2 -> 3)
new Color(153, 153, 0), // 3 (3 -> 4)
new Color(102, 153, 0), // 4 (5 -> 6)
new Color(51, 153, 0), // 5 (6 -> 7)
new Color(0, 153, 0), // 6 (1 -> 8)
new Color(0, 153, 51), // 7 (8 -> 9)
new Color(0, 153, 102), // 8 (9 -> 10)
new Color(0, 153, 153), // 9 (1 -> 11)
new Color(0, 102, 153), // 10 (11 -> 12)
new Color(0, 51, 153), // 11 (12 -> 13)
new Color(0, 0, 153), // 12 (1 -> 0)
new Color(51, 0, 153), // 13 (0 -> 14)
new Color(102, 0, 153), // 14 (14 -> 16)
new Color(153, 0, 153), // 15 (0 -> 15)
new Color(153, 0, 102), // 16 (15 -> 17)
new Color(153, 0, 51), // 17 (2 -> 16)
new Color(153, 153, 153), // 18 (5 -> 17)
};
/**
* Constants lists.
*/
private static final Color[] CLASS_COLOR = new Color[] { aliceblue, chartreuse, aqua, aquamarine, azure, beige,
bisque, blanchedalmond, blueviolet, burlywood, cadetblue, antiquewhite, chocolate, coral, cornflowerblue,
cornsilk, crimson, cyan, darkcyan, darkgoldenrod, darkgray, darkkhaki, darkorange, darkorchid, darksalmon,
darkseagreen, darkturquoise, darkviolet, deeppink, deepskyblue, dodgerblue, firebrick, floralwhite,
forestgreen, fuchsia, gainsboro, ghostwhite, gold, goldenrod, salmon, tan, honeydew, hotpink, indianred,
ivory, khaki, lavender, lavenderblush, lawngreen, lemonchiffon, lightblue, lightcoral, lightcyan,
lightgoldenrodyellow, lightgreen, lightgrey, lightgreen, lightpink, lightsalmon, lightseagreen,
lightskyblue, lightslategray, lightslategray, lightsteelblue, lightyellow, lime, limegreen, linen, magenta,
mediumaquamarine, mediumorchid, mediumpurple, mediumseagreen, mediumslateblue, mediumspringgreen,
mediumturquoise, mediumvioletred, mintcream, mistyrose, mocassin, navajowhite, oldlace, olive, olivedrab,
orange, orangered, orchid, palegoldenrod, palegreen, paleturquoise, palevioletred, papayawhip, peachpuff,
peru, pink, plum, powderblue, purple, red, rosybrown, royalblue, saddlebrown, green, sandybrown, seagreen,
seashell, sienna, silver, skyblue, slateblue, slategray, slategray, snow, springgreen, steelblue,
greenyellow, teal, thistle, tomato, turquoise, violet, wheat, white, whitesmoke, yellow, yellowgreen };
/**
* List of constants.
*/
public static final Color[] CLASS_COLOR2 = new Color[] { yellow, yellowgreen, turquoise, springgreen, skyblue,
slateblue, red, violet, olivedrab, royalblue, darkorange, mediumblue, deeppink, chartreuse, orchid,
palegreen, aqua, orange, navy };
/**
* Returns a different color for each ID. The colors wrap around when the ID exceeds the
* number of predefined colors.
* @param id the unique id to pick a color for.
* @return a color selected from the predefined list based on the input id
*/
public static Color getClassColor(int id) {
return CLASS_COLOR[id % CLASS_COLOR.length];
}
/**
* Augments the input image with a labeled rectangle (e.g. bounding box)
* with coordinates: (x1, y1, x2, y2).
* @param image Input image to be augmented with a labeled rectangle.
* @param cid Unique id used to select the color of the rectangle. Used only if the
* colorAgnostic is set to false.
* @param title rectangle title
* @param x1 x coordinate of the top left corner of the bounding box
* @param y1 y coordinate of the top left corner of the bounding box
* @param x2 x coordinate of the bottom right corner of the bounding box
* @param y2 y coordinate of the bottom right corner of the bounding box
* @param colorAgnostic If set to false the cid is used to select the bounding box
* color. Uses the AGNOSTIC_COLOR otherwise.
*/
public static void drawBoundingBox(BufferedImage image, int cid, String title, int x1, int y1, int x2, int y2,
boolean colorAgnostic) {
Graphics2D g = image.createGraphics();
g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON);
Color labelColor = colorAgnostic ? AGNOSTIC_COLOR : GraphicsUtils.getClassColor(cid);
g.setColor(labelColor);
g.setFont(DEFAULT_FONT);
FontMetrics fontMetrics = g.getFontMetrics();
Stroke oldStroke = g.getStroke();
g.setStroke(new BasicStroke(LINE_THICKNESS));
g.drawRect(x1, y1, (x2 - x1), (y2 - y1));
g.setStroke(oldStroke);
Rectangle2D rect = fontMetrics.getStringBounds(title, g);
g.setColor(labelColor);
g.fillRect(x1, y1 - fontMetrics.getAscent(), (int) rect.getWidth() + 2 * TITLE_OFFSET, (int) rect.getHeight());
g.setColor(getTextColor(labelColor));
g.drawString(title, x1 + TITLE_OFFSET, y1);
}
/**
* Depending on the darkness of the background, picks a dark or light font color.
* @param backGroundColor background color within which the text is drawn
* @return a text color that contrasts with the given background color.
*/
private static Color getTextColor(Color backGroundColor) {
double y = (299 * backGroundColor.getRed() + 587 * backGroundColor.getGreen() + 114 * backGroundColor.getBlue())
/ 1000;
return y >= 128 ? Color.black : Color.white;
}
public static BufferedImage createMaskImage(float[][] maskPixels, int scaledWidth, int scaledHeight,
Color maskColor) {
int maskWidth = maskPixels.length;
int maskHeight = maskPixels[0].length;
int[] maskArray = new int[maskWidth * maskHeight];
int k = 0;
for (int i = 0; i < maskHeight; i++) {
for (int j = 0; j < maskWidth; j++) {
maskArray[k++] = grayScaleToARGB(maskPixels[i][j], maskColor);
}
}
// Turn the pixel array into image;
BufferedImage maskImage = new BufferedImage(maskWidth, maskHeight, BufferedImage.TYPE_INT_ARGB);
maskImage.setRGB(0, 0, maskWidth, maskHeight, maskArray, 0, maskWidth);
// Stretch the image to fit the target box width and height!
return toBufferedImage(maskImage.getScaledInstance(scaledWidth, scaledHeight, Image.SCALE_DEFAULT));
}
/**
* Converts a grayscale value (e.g. a value between 0 and 1) into ARGB.
* @param grayScale - value between 0 and 1
* @param maskColor - desired mask color
* @return Returns an ARGB color based on the grayscale and the mask colors
*/
private static int grayScaleToARGB(float grayScale, Color maskColor) {
if (maskColor != null) {
float r = col(maskColor.getRed(), grayScale);
float g = col(maskColor.getGreen(), grayScale);
float b = col(maskColor.getBlue(), grayScale);
float t = grayScale * 0.7f;
return new Color(r, g, b, t).getRGB();
}
return new Color(grayScale, grayScale, grayScale, grayScale).getRGB();
}
private static float col(int channelColor, float grayScale) {
// return ((float) channelColor / 255) * grayScale;
return ((float) channelColor / 255);
}
public static BufferedImage toBufferedImage(Image img) {
// if (img instanceof BufferedImage) {
// return (BufferedImage) img;
// }
// Create a buffered image with transparency
BufferedImage bimage = new BufferedImage(img.getWidth(null), img.getHeight(null), BufferedImage.TYPE_INT_ARGB);
// Draw the image on to the buffered image
Graphics2D bGr = bimage.createGraphics();
bGr.drawImage(img, 0, 0, null);
bGr.dispose();
// Return the buffered image
return bimage;
}
public static BufferedImage overlayImages(BufferedImage bgImage, BufferedImage fgImage, int fgX, int fgY) {
// Foreground image width and height cannot be greater than background image width
// and height.
if (fgImage.getHeight() > bgImage.getHeight() || fgImage.getWidth() > bgImage.getWidth()) {
throw new IllegalArgumentException("Foreground Image Is Bigger In One or Both Dimensions"
+ "\nCannot proceed with overlay." + "\n\n Please use smaller Image for foreground");
}
// Create a Graphics from the background image
Graphics2D g = bgImage.createGraphics();
// Set Antialias Rendering
g.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);
// Draw background image at location (0,0)
g.drawImage(bgImage, 0, 0, null);
// Draw foreground image at location (fgX, fgY)
g.drawImage(fgImage, fgX, fgY, null);
g.dispose();
return bgImage;
}
/**
* Convert {@link BufferedImage} to byte array.
* @param image the image to be converted
* @param format the output image format
* @return New array of bytes
*/
public static byte[] toImageByteArray(BufferedImage image, String format) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
ImageIO.write(image, format, baos);
byte[] bytes = baos.toByteArray();
return bytes;
}
catch (IOException e) {
throw new IllegalStateException(e);
}
finally {
try {
baos.close();
}
catch (IOException e) {
throw new IllegalStateException(e);
}
}
}
/**
* @param bufferedImage buffer to be converted in to raw array
* @return flat byte array representing the buffered image
*/
public static byte[] toRawByteArray(BufferedImage bufferedImage) {
return ((DataBufferByte) bufferedImage.getRaster().getDataBuffer()).getData();
}
/**
* Converts the image to the desired type, for example BufferedImage.TYPE_3BYTE_BGR.
* @param image image to be converted into buffer
* @param imageType desired type
* @return BufferedImage
*/
public static BufferedImage toBufferedImageType(BufferedImage image, int imageType) {
if (image.getType() == imageType) {
return image;
}
BufferedImage outputImage = new BufferedImage(image.getWidth(), image.getHeight(), imageType);
outputImage.getGraphics().drawImage(image, 0, 0, null);
return outputImage;
}
public static byte[] toImageToBytes(String imageUri) throws IOException {
try (InputStream is = new DefaultResourceLoader().getResource(imageUri).getInputStream()) {
return IOUtils.toByteArray(is);
}
}
/**
* Loads a resource as a byte array. Supports the http:, file: and classpath: URI schemes.
* @param resourceUri resource URI
* @return Returns resources referred by the resourceUri as a byte array
* @throws IOException failure due to missing resource or invalid URI
*/
public static byte[] loadAsByteArray(String resourceUri) throws IOException {
Resource expectedPoseResponse = new DefaultResourceLoader().getResource(resourceUri);
try (InputStream is = expectedPoseResponse.getInputStream()) {
return IOUtils.toByteArray(is);
}
}
}
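
The byte-array conversion in the removed utility above can be sketched with try-with-resources, which avoids the explicit finally block. This is a minimal, JDK-only illustration; the class name `ImageBytesSketch` and the method `toBytes` are hypothetical and were not part of the removed module.

```java
import java.awt.image.BufferedImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import javax.imageio.ImageIO;

// Hypothetical helper for illustration only; not part of the removed module.
final class ImageBytesSketch {

    private ImageBytesSketch() {
    }

    // Encodes the image in the given format (for example "png") and returns the encoded bytes.
    static byte[] toBytes(BufferedImage image, String format) {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            // ImageIO.write returns false when no registered writer supports the format.
            if (!ImageIO.write(image, format, baos)) {
                throw new IllegalArgumentException("No image writer found for format: " + format);
            }
            return baos.toByteArray();
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Checking the boolean returned by `ImageIO.write` is a design choice of this sketch: without the check, an unsupported format would silently yield an empty array.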

View File

@@ -1,47 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.deprecated;
import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Maps domain objects into JSON strings.
*
* @author Christian Tzolov
*/
public class JsonMapperFunction implements Function<Object, String> {
private static final Log logger = LogFactory.getLog(JsonMapperFunction.class);
@Override
public String apply(Object o) {
try {
return new ObjectMapper().writeValueAsString(o);
}
catch (JsonProcessingException e) {
logger.error("Failed to encode the object detections into JSON message", e);
}
return "ERROR";
}
}

View File

@@ -1,134 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.deprecated;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Session.Runner;
import org.tensorflow.Tensor;
import org.springframework.cloud.fn.common.tensorflow.util.CachedModelExtractor;
import org.springframework.cloud.fn.common.tensorflow.util.ModelExtractor;
import org.springframework.core.io.Resource;
/**
* @author Christian Tzolov
*/
public class TensorFlowService implements Function<Map<String, Tensor<?>>, Map<String, Tensor<?>>>, AutoCloseable {
private static final Log logger = LogFactory.getLog(TensorFlowService.class);
private final Session session;
private final List<String> fetchedNames;
private final boolean autoCloseFeedTensors;
public TensorFlowService(Resource modelLocation, List<String> fetchedNames) {
this(modelLocation, fetchedNames, false);
}
public TensorFlowService(Resource modelLocation, List<String> fetchedNames, boolean cacheModel) {
this(modelLocation, fetchedNames, cacheModel, false);
}
public TensorFlowService(Resource modelLocation, List<String> fetchedNames, boolean cacheModel,
boolean autoCloseFeedTensors) {
if (logger.isInfoEnabled()) {
logger.info("Loading TensorFlow graph model: " + modelLocation);
}
this.autoCloseFeedTensors = autoCloseFeedTensors;
this.fetchedNames = fetchedNames;
Graph graph = new Graph();
byte[] model = cacheModel ? new CachedModelExtractor().getModel(modelLocation)
: new ModelExtractor().getModel(modelLocation);
graph.importGraphDef(model);
this.session = new Session(graph);
}
/**
* Evaluates a pre-trained tensorflow model (encoded as {@link Graph}). Use the feeds
* parameter to feed in the model input data and fetch-names to specify the output
* tensors.
* @param feeds Named map of input tensors.
* @return Returns the computed output tensors. The names of the output tensors are
* defined by the fetchedNames argument
*/
@Override
public Map<String, Tensor<?>> apply(Map<String, Tensor<?>> feeds) {
Runner runner = this.session.runner();
// Keep tensor references to release them in the finally block
Tensor[] feedTensors = new Tensor[feeds.size()];
try {
// Feed in the input named tensors
int inputIndex = 0;
for (Entry<String, Tensor<?>> e : feeds.entrySet()) {
String feedName = e.getKey();
feedTensors[inputIndex] = e.getValue();
runner = runner.feed(feedName, feedTensors[inputIndex]);
inputIndex++;
}
// Set the tensor name to be fetched after the evaluation
for (String fetchName : this.fetchedNames) {
runner.fetch(fetchName);
}
// Evaluate the input
List<Tensor<?>> outputTensors = runner.run();
// Extract the output tensors
Map<String, Tensor<?>> outTensorMap = new HashMap<>();
for (int outputIndex = 0; outputIndex < this.fetchedNames.size(); outputIndex++) {
outTensorMap.put(this.fetchedNames.get(outputIndex), outputTensors.get(outputIndex));
}
return outTensorMap;
}
finally {
if (this.autoCloseFeedTensors) {
// Release all feed tensors
for (Tensor tensor : feedTensors) {
if (tensor != null) {
tensor.close();
}
}
}
}
}
@Override
public void close() {
logger.info("Close TensorFlow Session!");
if (this.session != null) {
this.session.close();
}
}
}

View File

@@ -1,157 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.util;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Utilities for AutoCloseable classes. Based on the Apache Drill AutoCloseables
* implementation.
*/
public final class AutoCloseables {
private AutoCloseables() {
}
private static final Logger LOGGER = LoggerFactory.getLogger(AutoCloseables.class);
public static AutoCloseable all(final Collection<? extends AutoCloseable> autoCloseables) {
return () -> close(autoCloseables);
}
public static AutoCloseable all(final Map<?, ? extends AutoCloseable>... autoCloseables) {
return () -> close(autoCloseables);
}
/**
* Closes all autoCloseables if not null and suppresses exceptions by adding them to
* t.
* @param t the throwable to add suppressed exception to
* @param autoCloseables the closeables to close
*/
public static void close(Throwable t, AutoCloseable... autoCloseables) {
close(t, Arrays.asList(autoCloseables));
}
/**
* Closes all autoCloseables if not null and suppresses exceptions by adding them to
* t.
* @param t the throwable to add suppressed exception to
* @param autoCloseables the closeables to close
*/
public static void close(Throwable t, Collection<? extends AutoCloseable> autoCloseables) {
try {
close(autoCloseables);
}
catch (Exception e) {
t.addSuppressed(e);
}
}
/**
* Closes all autoCloseables if not null and suppresses subsequent exceptions if more
* than one.
* @param autoCloseables the closeables to close
*/
public static void close(AutoCloseable... autoCloseables) throws Exception {
close(Arrays.asList(autoCloseables));
}
/**
* Closes all autoCloseables if not null and suppresses subsequent exceptions if more
* than one.
* @param autoCloseables the closeables to close
*/
public static void close(Iterable<? extends AutoCloseable> autoCloseables) throws Exception {
Exception topLevelException = null;
for (AutoCloseable closeable : autoCloseables) {
try {
if (closeable != null) {
closeable.close();
}
}
catch (Exception e) {
if (topLevelException == null) {
topLevelException = e;
}
else {
topLevelException.addSuppressed(e);
}
}
}
if (topLevelException != null) {
throw topLevelException;
}
}
/**
* Closes all autoCloseables entry values if not null and suppresses subsequent
* exceptions if more than one.
* @param closableMaps the closeables to close
*/
public static void close(Map<?, ? extends AutoCloseable>... closableMaps) throws Exception {
Exception topLevelException = null;
for (Map<?, ? extends AutoCloseable> closableMap : closableMaps) {
for (AutoCloseable closeable : closableMap.values()) {
try {
if (closeable != null) {
closeable.close();
}
}
catch (Exception e) {
if (topLevelException == null) {
topLevelException = e;
}
else {
topLevelException.addSuppressed(e);
}
}
}
}
if (topLevelException != null) {
throw topLevelException;
}
}
/**
* Close all without caring about thrown exceptions.
* @param closeables - array containing auto closeables
*/
public static void closeSilently(AutoCloseable... closeables) {
Arrays.stream(closeables).filter(Objects::nonNull).forEach(target -> {
try {
target.close();
}
catch (Exception e) {
LOGGER.warn(String.format("Exception was thrown while closing auto closeable: %s", target), e);
}
});
}
}
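
The close-and-suppress behaviour described in the Javadoc above can be sketched in a few lines of plain Java. This is an illustration under stated assumptions, not the removed implementation: the class `CloseAllSketch` and its `closeAll` method are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration of the close-all pattern; not the removed class itself.
final class CloseAllSketch {

    private CloseAllSketch() {
    }

    // Closes every non-null resource. The first failure is rethrown at the end and
    // any later failures are attached to it as suppressed exceptions.
    static void closeAll(List<? extends AutoCloseable> resources) throws Exception {
        Exception first = null;
        for (AutoCloseable resource : resources) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            }
            catch (Exception e) {
                if (first == null) {
                    first = e;
                }
                else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }

    public static void main(String[] args) {
        AutoCloseable a = () -> { throw new IllegalStateException("a failed"); };
        AutoCloseable b = () -> { throw new IllegalStateException("b failed"); };
        try {
            closeAll(Arrays.asList(a, null, b));
        }
        catch (Exception e) {
            // prints "a failed, suppressed: 1"
            System.out.println(e.getMessage() + ", suppressed: " + e.getSuppressed().length);
        }
    }
}
```

Every resource gets a close attempt even when an earlier one fails, which is the point of collecting exceptions instead of throwing immediately.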

View File

@@ -1,110 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.util;
import java.io.File;
import java.io.FileInputStream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
/**
* Extends the {@link ModelExtractor} to allow keeping a local copy (cache) of the loaded
* model (protobuf) files.
*
* @author Christian Tzolov
*/
public class CachedModelExtractor extends ModelExtractor {
private static final Log logger = LogFactory.getLog(CachedModelExtractor.class);
/**
* Parent folder under which the model files are cached.
*/
public String cacheRootDirectory = new File(System.getProperty("java.io.tmpdir"), "mind-model").getAbsolutePath();
public String getCacheRootDirectory() {
return cacheRootDirectory;
}
public void setCacheRootDirectory(String cacheRootDirectory) {
this.cacheRootDirectory = cacheRootDirectory;
}
public CachedModelExtractor() {
super();
}
public CachedModelExtractor(String frozenGraphFileExtension) {
super(frozenGraphFileExtension);
}
@Override
public byte[] getModel(String modelUri) {
return this.getModel(new DefaultResourceLoader().getResource(modelUri));
}
@Override
public byte[] getModel(Resource modelResource) {
try {
File rootFolder = new File(this.cacheRootDirectory);
if (!rootFolder.exists()) {
logger.info("Create Model Cache root folder: " + rootFolder.getAbsolutePath());
rootFolder.mkdirs();
}
Validate.isTrue(rootFolder.isDirectory(), "The cache root folder must be a Directory");
String fileName = modelResource.getFilename();
String fragment = modelResource.getURI().getFragment();
File cachedFile = StringUtils.isEmpty(fragment) ? new File(rootFolder, fileName)
: new File(rootFolder, fileName + "_" + fragment);
if (cachedFile.exists()) {
logger.info("Load model " + modelResource.toString() + " from cache: " + cacheRootDirectory);
// Close the cached file stream once the bytes are read
try (FileInputStream fis = new FileInputStream(cachedFile)) {
return IOUtils.toByteArray(fis);
}
}
byte[] model = super.getModel(modelResource);
// cache the file
FileUtils.writeByteArrayToFile(cachedFile, model);
logger.info("Caching the " + modelResource.toString() + " model at: " + cachedFile);
return model;
}
catch (Exception e) {
throw new IllegalStateException("Failed to extract a model from: " + modelResource.getDescription(), e);
}
}
public void emptyModelCache() {
File rootFolder = new File(this.cacheRootDirectory);
if (rootFolder.exists()) {
logger.info("Empty Model Cache at:" + rootFolder.getAbsolutePath());
// File.delete() removes only empty directories, so delete the cache recursively
FileUtils.deleteQuietly(rootFolder);
rootFolder.mkdirs();
}
}
}
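
The cache file naming used above (file name, plus `_` and the URI fragment when one is present) can be illustrated with `java.net.URI` alone. This is a sketch under assumptions: the class `CacheFileNameSketch`, the method `cacheFileName` and the example.org URL are hypothetical, and the URI is assumed to be hierarchical with a non-empty path.

```java
import java.net.URI;

// Hypothetical illustration; the removed class works on Spring Resource objects instead.
final class CacheFileNameSketch {

    private CacheFileNameSketch() {
    }

    // Derives the cache file name: "<file name>" or "<file name>_<fragment>".
    // Assumes a hierarchical URI, so that getPath() is not null.
    static String cacheFileName(URI modelUri) {
        String path = modelUri.getPath();
        String fileName = path.substring(path.lastIndexOf('/') + 1);
        String fragment = modelUri.getFragment();
        return (fragment == null || fragment.isEmpty()) ? fileName : fileName + "_" + fragment;
    }

    public static void main(String[] args) {
        // prints "myarchive.tar.gz_model.pb"
        System.out.println(cacheFileName(URI.create("https://example.org/models/myarchive.tar.gz#model.pb")));
    }
}
```

Appending the fragment keeps two different models extracted from the same archive from overwriting each other in the cache.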

View File

@@ -1,253 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow.util;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.util.Optional;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.compress.archivers.ArchiveStreamFactory;
import org.apache.commons.compress.compressors.CompressorInputStream;
import org.apache.commons.compress.compressors.CompressorStreamFactory;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
/**
* Extracts a pre-trained (frozen) Tensorflow model, referenced by a URI, into a byte
* array. The 'http://', 'file://' and 'classpath://' URI schemes are supported.
*
* Models can be extracted either from raw files or from compressed archives. When extracted
* from an archive the model file name can optionally be provided as a URI fragment. For
* example, for the resource http://myarchive.tar.gz#model.pb the myarchive.tar.gz is
* traversed to uncompress and extract the model.pb file as a byte array. If the file name
* is not provided as a URI fragment then the first file in the archive with extension .pb
* is extracted.
*
* @author Christian Tzolov
*/
public class ModelExtractor {
private static final String DEFAULT_FROZEN_GRAPH_FILE_EXTENSION = ".pb";
/**
* When an archive resource is referenced but no URI fragment is provided (to specify
* the target file name in the archive), the extractor selects the first file in the
* archive whose extension matches the frozenGraphFileExtension (defaults to .pb).
*/
public final String frozenGraphFileExtension;
public ModelExtractor() {
this(DEFAULT_FROZEN_GRAPH_FILE_EXTENSION);
}
public ModelExtractor(String frozenGraphFileExtension) {
this.frozenGraphFileExtension = frozenGraphFileExtension;
}
public byte[] getModel(String modelUri) {
return getModel(new DefaultResourceLoader().getResource(modelUri));
}
public byte[] getModel(Resource modelResource) {
Validate.notNull(modelResource, "Not null model resource is required!");
try (InputStream is = modelResource.getInputStream(); InputStream bi = new BufferedInputStream(is)) {
String[] archiveCompressor = detectArchiveAndCompressor(modelResource.getFilename());
String archive = archiveCompressor[0];
String compressor = archiveCompressor[1];
String fragment = modelResource.getURI().getFragment();
if (StringUtils.isNotBlank(compressor)) {
try (CompressorInputStream cis = new CompressorStreamFactory().createCompressorInputStream(compressor,
bi)) {
if (StringUtils.isNotBlank(archive)) {
try (ArchiveInputStream ais = new ArchiveStreamFactory().createArchiveInputStream(archive,
cis)) {
// Compressor with Archive
return findInArchiveStream(fragment, ais);
}
}
else { // Compressor only
return IOUtils.toByteArray(cis);
}
}
}
else if (StringUtils.isNotBlank(archive)) { // Archive only
try (ArchiveInputStream ais = new ArchiveStreamFactory().createArchiveInputStream(archive, bi)) {
return findInArchiveStream(fragment, ais);
}
}
else {
// No compressor nor Archive
return IOUtils.toByteArray(bi);
}
}
catch (Exception e) {
throw new IllegalStateException("Failed to extract a model from: " + modelResource.getDescription(), e);
}
}
/**
* Traverses the Archive to find either an entry that matches the
* modelFileNameInArchive name (if not empty) or an entry that ends in .pb if the
* modelFileNameInArchive is empty.
* @param modelFileNameInArchive Optional name of the archive entry that represents
* the frozen model file. If empty the archive will be searched for the first entry
* that ends in .pb
* @param archive Archive stream to be traversed
*
*/
private byte[] findInArchiveStream(String modelFileNameInArchive, ArchiveInputStream archive) throws IOException {
ArchiveEntry entry;
while ((entry = archive.getNextEntry()) != null) {
// System.out.println(entry.getName() + " : " + entry.isDirectory());
if (archive.canReadEntryData(entry) && !entry.isDirectory()) {
if ((StringUtils.isNotBlank(modelFileNameInArchive) && entry.getName().endsWith(modelFileNameInArchive))
|| (!StringUtils.isNotBlank(modelFileNameInArchive)
&& entry.getName().endsWith(this.frozenGraphFileExtension))) {
return IOUtils.toByteArray(archive);
}
}
}
throw new IllegalArgumentException("No model is found in the archive");
}
/**
* Detect the Archive and the Compressor from the file extension.
* @param fileName File name with extension.
* @return Returns a tuple of the detected (Archive, Compressor). A null element means
* that no archive or compressor was detected. The (null, null) response means that
* neither an Archive nor a Compressor was discovered.
*/
private String[] detectArchiveAndCompressor(String fileName) {
String normalizedFileName = fileName.trim().toLowerCase();
if (normalizedFileName.endsWith(".tar.gz") || normalizedFileName.endsWith(".tgz")
|| normalizedFileName.endsWith(".taz")) {
return new String[] { ArchiveStreamFactory.TAR, CompressorStreamFactory.GZIP };
}
else if (normalizedFileName.endsWith(".tar.bz2") || normalizedFileName.endsWith(".tbz2")
|| normalizedFileName.endsWith(".tbz")) {
return new String[] { ArchiveStreamFactory.TAR, CompressorStreamFactory.BZIP2 };
}
else if (normalizedFileName.endsWith(".cpgz")) {
return new String[] { ArchiveStreamFactory.CPIO, CompressorStreamFactory.GZIP };
}
else if (hasArchive(normalizedFileName)) {
return new String[] { findArchive(normalizedFileName).get(), null };
}
else if (hasCompressor(normalizedFileName)) {
return new String[] { null, findCompressor(normalizedFileName).get() };
}
else if (normalizedFileName.endsWith(".gzip")) {
return new String[] { null, CompressorStreamFactory.GZIP };
}
else if (normalizedFileName.endsWith(".bz2") || normalizedFileName.endsWith(".bz")) {
return new String[] { null, CompressorStreamFactory.BZIP2 };
}
// No archived/compressed
return new String[] { null, null };
}
private boolean hasArchive(String normalizedFileName) {
return findArchive(normalizedFileName).isPresent();
}
private Optional<String> findArchive(String normalizedFileName) {
return new ArchiveStreamFactory().getInputStreamArchiveNames()
.stream()
.filter(arch -> normalizedFileName.endsWith("." + arch))
.findFirst();
}
private boolean hasCompressor(String normalizedFileName) {
return findCompressor(normalizedFileName).isPresent();
}
private Optional<String> findCompressor(String normalizedFileName) {
return new CompressorStreamFactory().getInputStreamCompressorNames()
.stream()
.filter(compressor -> normalizedFileName.endsWith("." + compressor))
.findFirst();
}
static {
disableSslVerification();
}
private static void disableSslVerification() {
try {
// Create a trust manager that does not validate certificate chains
TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
public X509Certificate[] getAcceptedIssuers() {
// The X509TrustManager contract expects a non-null (possibly empty) array
return new X509Certificate[0];
}
public void checkClientTrusted(X509Certificate[] certs, String authType) {
}
public void checkServerTrusted(X509Certificate[] certs, String authType) {
}
} };
// Install the all-trusting trust manager
SSLContext sc = SSLContext.getInstance("SSL");
sc.init(null, trustAllCerts, new java.security.SecureRandom());
HttpsURLConnection.setDefaultSSLSocketFactory(sc.getSocketFactory());
// Create all-trusting host name verifier
HostnameVerifier allHostsValid = new HostnameVerifier() {
public boolean verify(String hostname, SSLSession session) {
return true;
}
};
// Install the all-trusting host verifier
HttpsURLConnection.setDefaultHostnameVerifier(allHostsValid);
}
catch (NoSuchAlgorithmException e) {
e.printStackTrace();
}
catch (KeyManagementException e) {
e.printStackTrace();
}
}
}
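
The extension-based detection above checks compound extensions such as `.tar.gz` before single ones such as `.gz`. A simplified, JDK-only sketch of that ordering follows. The class `ArchiveDetectionSketch` is hypothetical, the string labels are illustrative rather than the Commons Compress constants, and only a handful of extensions are covered.

```java
import java.util.Locale;

// Hypothetical, simplified illustration of extension-based detection.
final class ArchiveDetectionSketch {

    private ArchiveDetectionSketch() {
    }

    // Returns {archive, compressor}; an element is null when nothing was detected for it.
    static String[] detect(String fileName) {
        String name = fileName.trim().toLowerCase(Locale.ROOT);
        // Compound extensions first, otherwise ".tar.gz" would match the plain ".gz" rule.
        if (name.endsWith(".tar.gz") || name.endsWith(".tgz")) {
            return new String[] { "tar", "gz" };
        }
        if (name.endsWith(".tar.bz2") || name.endsWith(".tbz2")) {
            return new String[] { "tar", "bzip2" };
        }
        // Archive only
        if (name.endsWith(".tar")) {
            return new String[] { "tar", null };
        }
        if (name.endsWith(".zip")) {
            return new String[] { "zip", null };
        }
        // Compressor only
        if (name.endsWith(".gz")) {
            return new String[] { null, "gz" };
        }
        if (name.endsWith(".bz2")) {
            return new String[] { null, "bzip2" };
        }
        // Neither archived nor compressed
        return new String[] { null, null };
    }
}
```

For instance, `detect("Model.TAR.GZ")` yields `{"tar", "gz"}`, while `detect("model.pb")` yields `{null, null}` and the file is read as is.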

Binary file not shown. Size before removal: 30 KiB

Binary file not shown. Size before removal: 33 KiB

View File

@@ -1,87 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import org.tensorflow.Tensor;
/**
* @author Christian Tzolov
*/
public class EnrichFromMemory implements AutoCloseable {
private final GraphRunner graph1;
private final GraphRunner graph2;
private final GraphRunner graph3;
public EnrichFromMemory() {
this.graph1 = new GraphRunner("x1", "y1").withGraphDefinition(
tf -> tf.withName("y1").math.mul(tf.withName("x1").placeholder(Integer.class), tf.constant(10)));
this.graph2 = new GraphRunner("x2", "y2").withGraphDefinition(
tf -> tf.withName("y2").math.mul(tf.withName("x2").placeholder(Integer.class), tf.constant(20)));
this.graph3 = new GraphRunner(Arrays.asList("x31", "x32"), Arrays.asList("y3"))
.withGraphDefinition(tf -> tf.withName("y3").math.add(tf.withName("x31").placeholder(Integer.class),
tf.withName("x32").placeholder(Integer.class)));
}
public int compute(Integer input) {
try (Tensor x = Tensor.create(input); GraphRunnerMemory memory = new GraphRunnerMemory();) {
Map<String, Tensor<?>> result = this.graph1.andThen(memory)
.andThen(graph2)
.andThen(memory)
// Retrieves graph1's y1 output from memory and adds it as a parameter with the same name
.andThen(Functions.enrichFromMemory(memory, "y1"))
.andThen(Functions.rename("y1", "x31", // renames the input y1 into x31
"y2", "x32" // renames the input y2 into x32
))
.andThen(graph3)
.andThen(memory)
.apply(Collections.singletonMap("x", x));
memory.getTensorMap().entrySet().forEach(e -> System.out.println(" " + e));
return result.get("y3").intValue();
}
}
@Override
public void close() {
this.graph1.close();
this.graph2.close();
this.graph3.close();
}
public static void main(String[] args) {
try (EnrichFromMemory example = new EnrichFromMemory()) {
for (int x = 0; x < 5; x++) {
System.out.println("For x = " + x + ", y = " + example.compute(x));
}
}
}
}

View File

@@ -1,52 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Collections;
import java.util.Map;
import org.tensorflow.Tensor;
/**
* @author Christian Tzolov
*/
public final class FunctionComposition {
private FunctionComposition() {
}
// y = (x * 2) + 20
//
// y1 = x1 * 2 , where x1 == x
// y2 = x2 + 20 , where x2 == y1 and y = y2
public static void main(String[] args) {
try (GraphRunner graph1 = new GraphRunner("x1", "y1").withGraphDefinition(
tf -> tf.withName("y1").math.mul(tf.withName("x1").placeholder(Integer.class), tf.constant(2)));
GraphRunner graph2 = new GraphRunner("x2", "y2").withGraphDefinition(tf -> tf.withName("y2").math
.add(tf.withName("x2").placeholder(Integer.class), tf.constant(20)));
Tensor x = Tensor.create(10);) {
Map<String, Tensor<?>> result = graph1.andThen(graph2).apply(Collections.singletonMap("x", x));
System.out.println("Result is: " + result.get("y2").intValue());
// Result is: 40
}
}
}

View File

@@ -1,61 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
/**
* @author Christian Tzolov
*/
public final class FunctionCompositionMultipleInputsOutputs {
private FunctionCompositionMultipleInputsOutputs() {
}
// y = (x * 2) + (x * 3)
//
// y11 = x1 * 2 , where x1 == x
// y12 = x1 * 3 , where x1 == x
// y2 = x21 + x22 , where x21 == y11, x22 == y12 and y == y2
public static void main(String[] args) {
try (GraphRunner graph1 = new GraphRunner(Arrays.asList("x1"), Arrays.asList("y11", "y12"))
.withGraphDefinition(tf -> {
Placeholder<Integer> x1 = tf.withName("x1").placeholder(Integer.class);
tf.withName("y11").math.mul(x1, tf.constant(2));
tf.withName("y12").math.mul(x1, tf.constant(3));
});
GraphRunner graph2 = new GraphRunner(Arrays.asList("x21", "x22"), Arrays.asList("y2"))
.withGraphDefinition(tf -> tf.withName("y2").math.add(tf.withName("x21").placeholder(Integer.class),
tf.withName("x22").placeholder(Integer.class)));
Tensor x = Tensor.create(10);) {
Map<String, Tensor<?>> result = graph1.andThen(Functions.rename("y11", "x21", "y12", "x22"))
.andThen(graph2)
.apply(Collections.singletonMap("x", x));
System.out.println("Result is: " + result.get("y2").intValue());
// Result is: 50
}
}
}
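
The composition pattern in the removed example above (run one stage, rename its outputs to the next stage's input names, run the next stage) does not depend on TensorFlow. A minimal sketch with plain `java.util.function.Function` over integer maps follows. The class `CompositionSketch` and its `rename` and `compute` methods are hypothetical stand-ins and do not reproduce the removed `GraphRunner` or `Functions` API.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in that uses plain integers instead of tensors.
final class CompositionSketch {

    private CompositionSketch() {
    }

    // Returns a function that renames map keys according to the from -> to mapping.
    // Keys without a mapping are passed through unchanged.
    static Function<Map<String, Integer>, Map<String, Integer>> rename(Map<String, String> mapping) {
        return input -> {
            Map<String, Integer> output = new HashMap<>();
            input.forEach((key, value) -> output.put(mapping.getOrDefault(key, key), value));
            return output;
        };
    }

    // y = (x * 2) + (x * 3)
    static int compute(int x) {
        // Stage 1: y11 = x1 * 2 and y12 = x1 * 3
        Function<Map<String, Integer>, Map<String, Integer>> stage1 = in -> {
            Map<String, Integer> out = new HashMap<>();
            out.put("y11", in.get("x1") * 2);
            out.put("y12", in.get("x1") * 3);
            return out;
        };
        // Stage 2: y2 = x21 + x22
        Function<Map<String, Integer>, Map<String, Integer>> stage2 = in -> Collections
            .singletonMap("y2", in.get("x21") + in.get("x22"));

        Map<String, String> mapping = new HashMap<>();
        mapping.put("y11", "x21");
        mapping.put("y12", "x22");

        return stage1.andThen(rename(mapping)).andThen(stage2).apply(Collections.singletonMap("x1", x)).get("y2");
    }

    public static void main(String[] args) {
        System.out.println("Result is: " + compute(10)); // Result is: 50
    }
}
```

Keeping the rename step as a separate function means each stage can declare its own input and output names without knowing about its neighbours.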

View File

@@ -1,71 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.common.tensorflow;
import java.util.Collections;
import java.util.Map;
import org.tensorflow.Tensor;
/**
* @author Christian Tzolov
*/
public class ReleaseTensorParameters implements AutoCloseable {
private final GraphRunner graph1;
private final GraphRunner graph2;
public ReleaseTensorParameters() {
this.graph1 = new GraphRunner("x1", "y1").withGraphDefinition(
tf -> tf.withName("y1").math.mul(tf.withName("x1").placeholder(Integer.class), tf.constant(2)));
this.graph2 = new GraphRunner("x2", "y2").withGraphDefinition(
tf -> tf.withName("y2").math.add(tf.withName("x2").placeholder(Integer.class), tf.constant(20)));
}
// y = (x * 2) + 20
public int compute(Integer input) {
try (Tensor x = Tensor.create(input); GraphRunnerMemory memory = new GraphRunnerMemory();) {
Map<String, Tensor<?>> result = this.graph1.andThen(memory)
.andThen(graph2)
.andThen(memory)
.apply(Collections.singletonMap("x", x));
memory.getTensorMap().entrySet().forEach(e -> System.out.println(" " + e));
return result.get("y2").intValue();
}
}
@Override
public void close() {
this.graph1.close();
this.graph2.close();
}
public static void main(String[] args) {
try (ReleaseTensorParameters example = new ReleaseTensorParameters()) {
for (int x = 0; x < 5; x++) {
System.out.println("For x = " + x + ", y = " + example.compute(x));
}
}
}
}

View File

@@ -4,7 +4,6 @@ ext {
springCloudAwsVersion = '3.0.4'
debeziumVersion = '2.5.0.Final'
protobufVersion='3.25.2'
springIntegrationAws = 'org.springframework.integration:spring-integration-aws:3.0.5'
ftpserverCore = 'org.apache.ftpserver:ftpserver-core:1.2.0'
@@ -12,5 +11,4 @@ ext {
twitter4jStream = 'org.twitter4j:twitter4j-stream:4.0.7'
greenmail = 'com.icegreen:greenmail:2.1.0-alpha-3'
apacheCuratorTest = 'org.apache.curator:curator-test:5.5.0'
protobufJava = "com.google.protobuf:protobuf-java:$protobufVersion"
}
}

View File

@@ -1,103 +0,0 @@
:images-asciidoc: https://raw.githubusercontent.com/spring-cloud/stream-applications/master/functions/function/image-recognition-function/src/main/resources/images/
# Image Recognition
[.lead]
Java model inference library for the https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models[Inception], https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md#pre-trained-models[MobileNetV1] and https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet#pretrained-models[MobileNetV2] image recognition architectures.
Provides real-time recognition of the https://dl.bintray.com/big-data/generic/imagenet_comp_graph_label_strings.txt[LSVRC-2012-CLS categories] in the input images.
[cols="1,2",frame=none,grid=none]
|===
| image:{images-asciidoc}/image-augmented.jpg[alt=Inception 1,width=100%]
|The https://github.com/spring-cloud/stream-applications/blob/master/functions/function/image-recognition-function/src/main/java/org/springframework/cloud/fn/image/recognition/ImageRecognition.java[ImageRecognition] takes an image and outputs a list of probable categories the image contains. The response is represented by the https://github.com/spring-cloud/stream-applications/blob/master/functions/function/image-recognition-function/src/main/java/org/springframework/cloud/fn/image/recognition/RecognitionResponse.java[RecognitionResponse] class.
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/deprecated/JsonMapperFunction.java[JsonMapperFunction] permits
converting the `RecognitionResponse` into JSON objects and the
https://github.com/spring-cloud/stream-applications/blob/master/functions/function/image-recognition-function/src/main/java/org/springframework/cloud/fn/image/recognition/ImageRecognitionAugmenter.java[ImageRecognitionAugmenter] can augment the input image with the detected categories (as shown in pic. 1).
|===
## Usage
Add the `image-recognition-function` dependency to the POM (use the latest version available):
[source,xml]
----
<dependency>
<groupId>org.springframework.cloud.fn</groupId>
<artifactId>image-recognition-function</artifactId>
<version>${revision}</version>
</dependency>
----
#### Example 1: Image Recognition
The following snippet demonstrates how to use `ImageRecognition` to detect the categories present in an input image.
It also shows how to convert the result into JSON format and augment the input image with the detected category labels.
[source,java,linenums]
----
ImageRecognition recognitionService = ImageRecognition.mobileNetV2(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz#mobilenet_v2_1.4_224_frozen.pb", //<1>
224, //<2>
5, //<3>
true); //<4>
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/giant_panda_in_beijing_zoo_1.jpg"); //<5>
List<RecognitionResponse> recognizedObjects = ImageRecognition.toRecognitionResponse(recognitionService.recognizeTopK(inputImage)); //<6>
----
<1> Downloads and loads a pre-trained `mobilenet_v2_1.4_224_frozen.pb` model.
Note that the first run downloads a few hundred MB.
Subsequent runs use the cached copy (4) instead.
The category labels for MobileNetV2 are resolved from `src/main/resources/labels/mobilenet_labels.txt`.
<2> The width and height (in pixels) of the normalized input image.
<3> The number of top-K results to return.
<4> Cache the model on the local file system.
<5> Load the image to recognize.
<6> Return the top-K most probable categories and their probabilities as a list of `RecognitionResponse` objects.
The `ImageRecognition.mobileNetV1` and `ImageRecognition.inception` factory methods help to load and configure pre-trained MobileNetV1 and Inception models.
Next, you can convert the result to JSON format.
[source,java,linenums]
----
String jsonRecognizedObjects = new JsonMapperFunction().apply(recognizedObjects);
----
.Sample Image Recognition JSON representation
[source,json]
----
[{"label":"giant panda","probability":0.9946687817573547},{"label":"Arctic fox","probability":0.0036631098482757807},{"label":"ice bear","probability":3.3782739774324E-4},{"label":"American black bear","probability":2.3452856112271547E-4},{"label":"skunk","probability":1.6454080468975008E-4}]
----
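The sample above lists the five categories with the highest probabilities, in descending order. As a rough illustration of what such a top-K selection amounts to, here is a minimal plain-Java sketch that works on an already computed probability array. `TopKSketch`, `topK` and the sample values are hypothetical and not part of the library, which computes the top-K values with a TensorFlow graph instead.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.IntStream;

// Hypothetical sketch, not part of the library: selects the k most probable
// labels from a plain probability array.
final class TopKSketch {

    // Returns the k labels with the highest probabilities, ordered from the
    // highest to the lowest probability.
    static Map<String, Double> topK(String[] labels, float[] probabilities, int k) {
        Map<String, Double> result = new LinkedHashMap<>();
        IntStream.range(0, probabilities.length)
            .boxed()
            // Sort the indices (not the values) by descending probability.
            .sorted((a, b) -> Float.compare(probabilities[b], probabilities[a]))
            .limit(k)
            .forEach(i -> result.put(labels[i], (double) probabilities[i]));
        return result;
    }

    public static void main(String[] args) {
        String[] labels = { "skunk", "giant panda", "ice bear" };
        float[] probabilities = { 0.1f, 0.7f, 0.2f };
        System.out.println(topK(labels, probabilities, 2).keySet());
    }
}
```

Sorting indices rather than looking labels up by probability value keeps two categories that happen to share the same probability distinct.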
Use the `ImageRecognitionAugmenter` to draw the recognized categories on top of the input image.
[source,java,linenums]
----
byte[] augmentedImage = new ImageRecognitionAugmenter().apply(inputImage, recognizedObjects); //<1>
IOUtils.write(augmentedImage, new FileOutputStream("./image-recognition/target/image-augmented.jpg"));//<2>
----
<1> Augment the image with the recognized categories (uses Java2D internally).
<2> Stores the augmented image in the `image-augmented.jpg` file.
.Augmented image-augmented.jpg file
image:{images-asciidoc}/image-recognition-panda-augmented.jpg[alt=Augmented,width=30%]
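Each label drawn on the augmented image combines the category name with its probability as a whole-number percentage. The following minimal sketch reproduces only that text formatting, based on how `ImageRecognitionAugmenter` builds its titles. `LabelTitleSketch` and `title` are hypothetical names used for illustration.

```java
// Hypothetical sketch, not part of the library: formats a label title the way
// the augmenter composes it ("<label>: <percent>%").
final class LabelTitleSketch {

    static String title(String label, double probability) {
        // The cast truncates toward zero, so 0.9946... becomes 99, not 100.
        int percent = (int) (100 * probability);
        return label + ": " + percent + "%";
    }

    public static void main(String[] args) {
        System.out.println(title("giant panda", 0.9946687817573547)); // giant panda: 99%
    }
}
```

Because the percentage is truncated rather than rounded, categories with a probability below 1% are rendered as `0%`.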
## Models
This implementation supports all pre-trained https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models[Inception], https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md#pre-trained-models[MobileNetV1] and https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet#pretrained-models[MobileNetV2] models.
The following URI notation can be used to download any of the models directly from the zoo.
----
http://<zoo model tar.gz url>#<frozen inference graph name.pb>
----
The `<frozen inference graph name.pb>` is the frozen model file name within the archive.
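To make the notation concrete, the following minimal sketch splits such a model URI into the archive location and the name of the frozen graph inside it, using only `java.net.URI`. `ModelUriSketch` and `split` are hypothetical names. The library's own resolution logic is not shown in this document and may differ.

```java
import java.net.URI;

// Hypothetical sketch, not part of the library: splits
// "<archive url>#<frozen graph name>" into its two parts.
final class ModelUriSketch {

    // Returns { archiveLocation, frozenGraphName }; the second element is null
    // when the URI has no fragment (i.e. it points at the model file directly).
    static String[] split(String modelUri) {
        String fragment = URI.create(modelUri).getFragment();
        String archive = (fragment == null) ? modelUri : modelUri.substring(0, modelUri.indexOf('#'));
        return new String[] { archive, fragment };
    }

    public static void main(String[] args) {
        String[] parts = split(
            "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz#mobilenet_v2_1.4_224_frozen.pb");
        System.out.println("archive: " + parts[0]);
        System.out.println("graph:   " + parts[1]);
    }
}
```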
TIP: To speed up the bootstrap, consider extracting the model and caching it locally.
Then you can use the `file://path-to-my-local-copy` URI scheme to access it.
NOTE: It is important to use the labels that correspond to the model being used!
The table below highlights this mapping.

View File

@@ -1,3 +0,0 @@
dependencies {
api project(':spring-tensorflow-common')
}

View File

@@ -1,285 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.op.nn.TopK;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.StreamUtils;
/**
* @author Christian Tzolov
*/
public class ImageRecognition implements AutoCloseable {
private final List<String> labels;
private final GraphRunner imageNormalization;
private final GraphRunner imageRecognition;
private final GraphRunner maxProbability;
private final GraphRunner topKProbabilities;
/**
* Instead of creating the {@link ImageRecognition} service explicitly via the
* constructor, you should consider the convenience factory methods below. E.g.
*
* {@link #inception(String, int, int, boolean)}
* {@link #mobileNetV1(String, int, int, boolean)}
* {@link #mobileNetV2(String, int, int, boolean)}
* @param modelUri location of the pre-trained model to use.
* @param labelsUri location of the list of pre-trained categories used by the
* model.
* @param imageRecognitionGraphInputName name of the Model's input node to send the
* input image to.
* @param imageRecognitionGraphOutputName name of the Model's output node to retrieve
* the predictions from.
* @param imageHeight normalized image height.
* @param imageWidth normalized image width.
* @param mean mean value to normalize the input image.
* @param scale scale to normalize the input image.
* @param responseSize Max number of predictions per recognize.
* @param cacheModel if true the pre-trained model is cached on the local file system.
*/
public ImageRecognition(String modelUri, String labelsUri, int imageHeight, int imageWidth, float mean, float scale,
String imageRecognitionGraphInputName, String imageRecognitionGraphOutputName, int responseSize,
boolean cacheModel) {
this.labels = labels(labelsUri);
/**
* Normalizes the raw input image into the format expected by the pre-trained
* Inception/MobileNetV1/MobileNetV2 models. Typically the model is trained
* with images scaled to a certain size. Usually it is 224x224 pixels, but it can
* also be 192x192, 160x160, 128x128 or 96x96. Use (imageHeight, imageWidth) to
* set the desired size. The colors, represented as R, G, B in 1 byte each, are
* converted to float using (Value - Mean)/Scale.
*
* imageHeight: normalized image height. imageWidth: normalized image width. mean:
* mean value to normalize the input image. scale: scale to normalize the input
* image.
*/
this.imageNormalization = new GraphRunner("raw_image", "normalized_image").withGraphDefinition(tf -> {
Placeholder<String> input = tf.withName("raw_image").placeholder(String.class);
final Operand<Float> decodedImage = tf.dtypes.cast(tf.image.decodeJpeg(input, DecodeJpeg.channels(3L)),
Float.class);
final Operand<Float> resizedImage = tf.image.resizeBilinear(tf.expandDims(decodedImage, tf.constant(0)),
tf.constant(new int[] { imageHeight, imageWidth }));
tf.withName("normalized_image").math.div(tf.math.sub(resizedImage, tf.constant(mean)), tf.constant(scale));
});
this.imageRecognition = new GraphRunner(imageRecognitionGraphInputName, imageRecognitionGraphOutputName)
.withGraphDefinition(new ProtoBufGraphDefinition(toResource(modelUri), cacheModel));
this.maxProbability = new GraphRunner(Arrays.asList("recognition_result"),
Arrays.asList("category", "probability"))
.withGraphDefinition(tf -> {
Placeholder<Float> input = tf.withName("recognition_result").placeholder(Float.class);
tf.withName("category").math.argMax(input, tf.constant(1));
tf.withName("probability").max(input, tf.constant(1));
});
this.topKProbabilities = new GraphRunner("recognition_result", "topK").withGraphDefinition(tf -> {
Placeholder<Float> input = tf.withName("recognition_result").placeholder(Float.class);
tf.withName("topK").nn.topK(input, tf.constant(responseSize), TopK.sorted(true));
});
}
/**
* Takes a byte-encoded image and returns the most probable category recognized in
* the image along with its probability.
* @param inputImage Byte array encoded image to recognize.
* @return Returns a single-entry map containing the name of the recognized
* category as key and the confidence as value.
*/
public Map<String, Double> recognizeMax(byte[] inputImage) {
try (Tensor inputTensor = Tensor.create(inputImage); GraphRunnerMemory memorize = new GraphRunnerMemory()) {
Map<String, Tensor<?>> max = this.imageNormalization.andThen(memorize)
.andThen(this.imageRecognition)
.andThen(memorize)
.andThen(this.maxProbability)
.andThen(memorize)
.apply(Collections.singletonMap("raw_image", inputTensor));
long[] category = new long[1];
max.get("category").copyTo(category);
float[] probability = new float[1];
max.get("probability").copyTo(probability);
return Collections.singletonMap(labels.get((int) category[0]), Double.valueOf(probability[0]));
}
}
/**
* Takes a byte-encoded input image and returns the top K most probable categories
* recognized in the image along with their probabilities.
* @param inputImage Byte array encoded image to recognize.
* @return Returns a map of key-value pairs. Every key-value pair represents a single
* category recognized. The key stands for the name(s) of the category while the value
* states the confidence that there is an object of this category. The entries in the
* Map are ordered from the highest to the lowest confidence.
*/
public Map<String, Double> recognizeTopK(byte[] inputImage) {
try (Tensor inputTensor = Tensor.create(inputImage); GraphRunnerMemory memorize = new GraphRunnerMemory()) {
Map<String, Tensor<?>> topKResults = this.imageNormalization.andThen(memorize)
.andThen(this.imageRecognition)
.andThen(memorize)
.andThen(this.topKProbabilities)
.andThen(memorize)
.apply(Collections.singletonMap("raw_image", inputTensor));
Tensor recognizedImagesTensor = memorize.getTensorMap().get(this.imageRecognition.getSingleFetchName());
float[][] results = new float[(int) recognizedImagesTensor.shape()[0]][(int) recognizedImagesTensor
.shape()[1]];
recognizedImagesTensor.copyTo(results);
Tensor<Float> topKTensor = topKResults.get("topK").expect(Float.class);
float[][] topK = new float[(int) topKTensor.shape()[0]][(int) topKTensor.shape()[1]];
topKTensor.copyTo(topK);
float min = topK[0][topK[0].length - 1];
Map<Float, Integer> valueToIndex = new HashMap<>();
for (int i = 0; i < results[0].length; i++) {
if (results[0][i] >= min) {
valueToIndex.put(results[0][i], i);
}
}
Map<String, Double> map = new LinkedHashMap<>();
for (float tk : topK[0]) {
map.put(labels.get(valueToIndex.get(tk)), (double) tk);
}
return map;
}
}
private Resource toResource(String uri) {
return new DefaultResourceLoader().getResource(uri);
}
/**
* Converts a labels resource into a list of strings.
* @return Returns a list of strings, one entry per category.
*/
private List<String> labels(String labelsUri) {
try (InputStream is = toResource(labelsUri).getInputStream()) {
return Arrays.asList(StreamUtils.copyToString(is, Charset.forName("UTF-8")).split("\n"));
}
catch (IOException e) {
throw new RuntimeException("Failed to initialize the Vocabulary", e);
}
}
/**
*
* The Inception graph uses "input" as input and "output" as output.
*
*/
public static ImageRecognition inception(String inceptionModelUri, int normalizedImageSize, int responseSize,
boolean cacheModel) {
return new ImageRecognition(inceptionModelUri, "classpath:/labels/inception_labels.txt", normalizedImageSize,
normalizedImageSize, 117f, 1f, "input", "output", responseSize, cacheModel);
}
/**
* Convenience for MobileNetV2 pre-trained models:
* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet#pretrained-models
*
* The normalized image size is always square (i.e. H=W).
*
* The MobileNetV2 graph uses "input" as input and "MobilenetV2/Predictions/Reshape_1"
* as output.
* @param mobileNetV2ModelUri model uri
* @param normalizedImageSize Depends on the pre-trained model used. Usually 224px is
* used.
* @param responseSize Number of responses for topK requests.
* @param cacheModel cache model
* @return ImageRecognition instance configured with a MobileNetV2 pre-trained
* model.
*/
public static ImageRecognition mobileNetV2(String mobileNetV2ModelUri, int normalizedImageSize, int responseSize,
boolean cacheModel) {
return new ImageRecognition(mobileNetV2ModelUri, "classpath:/labels/mobilenet_labels.txt", normalizedImageSize,
normalizedImageSize, 0f, 127f, "input", "MobilenetV2/Predictions/Reshape_1", responseSize, cacheModel);
}
/**
* Convenience for MobileNetV1 pre-trained models:
* https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md#pre-trained-models
*
* The MobileNetV1 graph uses "input" as input and "MobilenetV1/Predictions/Reshape_1"
* as output.
*
*/
public static ImageRecognition mobileNetV1(String mobileNetV1ModelUri, int normalizedImageSize, int responseSize,
boolean cacheModel) {
return new ImageRecognition(mobileNetV1ModelUri, "classpath:/labels/mobilenet_labels.txt", normalizedImageSize,
normalizedImageSize, 0f, 127f, "input", "MobilenetV1/Predictions/Reshape_1", responseSize, cacheModel);
}
/**
* Convert image recognition results into {@link RecognitionResponse} domain list.
* @param recognitionMap map containing the category names and their probabilities.
* Returned by the {@link ImageRecognition#recognizeMax(byte[])} and the
* {@link ImageRecognition#recognizeTopK(byte[])} methods.
* @return List of {@link RecognitionResponse} objects representing the
* name-to-probability pairs in the input map.
*/
public static List<RecognitionResponse> toRecognitionResponse(Map<String, Double> recognitionMap) {
return recognitionMap.entrySet()
.stream()
.map(nameProbabilityPair -> new RecognitionResponse(nameProbabilityPair.getKey(),
nameProbabilityPair.getValue()))
.collect(Collectors.toList());
}
@Override
public void close() {
this.imageNormalization.close();
this.imageRecognition.close();
this.maxProbability.close();
this.topKProbabilities.close();
}
}

View File

@@ -1,103 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
import java.awt.Color;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.geom.Rectangle2D;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Augments the input image with the recognized labels.
*
* @author Christian Tzolov
*/
public class ImageRecognitionAugmenter implements BiFunction<byte[], List<RecognitionResponse>, byte[]> {
private static final Log logger = LogFactory.getLog(ImageRecognitionAugmenter.class);
/** IMAGE_FORMAT. */
public static final String IMAGE_FORMAT = "jpg";
private final Color textColor = Color.BLACK;
private final Color bgColor = new Color(167, 252, 0);
public ImageRecognitionAugmenter() {
}
/**
* Augment the input image by adding the recognized classes.
* @param imageBytes input image as byte array
* @param result computed recognition labels
* @return the image augmented with the recognized labels.
*/
@Override
public byte[] apply(byte[] imageBytes, List<RecognitionResponse> result) {
try {
if (result != null) {
BufferedImage originalImage = ImageIO.read(new ByteArrayInputStream(imageBytes));
Graphics2D g = originalImage.createGraphics();
g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON);
FontMetrics fm = g.getFontMetrics();
int x = 1;
int y = 1;
for (RecognitionResponse r : result) {
String labelName = r.getLabel();
int probability = (int) (100 * r.getProbability());
String title = labelName + ": " + probability + "%";
Rectangle2D rect = fm.getStringBounds(title, g);
g.setColor(bgColor);
g.fillRect(x, y, (int) rect.getWidth() + 6, (int) rect.getHeight());
g.setColor(textColor);
g.drawString(title, x + 3, (int) (y + rect.getHeight() - 3));
y = (int) (y + rect.getHeight() + 1);
}
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ImageIO.write(originalImage, IMAGE_FORMAT, baos);
baos.flush();
imageBytes = baos.toByteArray();
baos.close();
}
}
catch (IOException e) {
logger.error("Failed to draw labels in the input image", e);
}
return imageBytes;
}
}

View File

@@ -1,57 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
/**
* @author Christian Tzolov
*/
public class RecognitionResponse {
private String label;
private Double probability;
public RecognitionResponse() {
}
public RecognitionResponse(String label, Double probability) {
this.label = label;
this.probability = probability;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public Double getProbability() {
return probability;
}
public void setProbability(Double probability) {
this.probability = probability;
}
@Override
public String toString() {
return "{label='" + label + "', probability=" + probability + '}';
}
}

View File

@@ -1,111 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition.util;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
/**
* Create a text file mapping label id to human readable string.
*
* Produces a text file where every line represents a single category. The line number
* represents the category id, while the line text is the human-readable name of the
* category with this ImageNet id.
*
* Based on
* https://github.com/tensorflow/models/blob/master/research/slim/datasets/imagenet.py#L66
*
* We retrieve a synset file, which contains a list of valid synset labels used by the
* ILSVRC competition. There is one synset per line, e.g. n01440764, n01443537. We also
* retrieve a synset_to_human_file, which contains a mapping from synsets to
* human-readable names for every synset in ImageNet. These are stored in a TSV format,
* as follows: "n02119247 black fox", "n02119359 silver fox". We assign each synset (in
* alphabetical order) an integer, starting from 1 (since 0 is reserved for the
* background class).
*
* @author Christian Tzolov
*/
public final class ImageNetReadableNamesWriter {
private ImageNetReadableNamesWriter() {
}
/** BASE_URL. */
public final static String BASE_URL = "https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data/";
/** SYNSET_URI. */
public final static String SYNSET_URI = BASE_URL + "imagenet_lsvrc_2015_synsets.txt";
/** SYNSET_TO_HUMAN_URI. */
public final static String SYNSET_TO_HUMAN_URI = BASE_URL + "imagenet_metadata.txt";
public static void main(String[] args) {
Charset utf8 = Charset.forName("UTF-8");
try (InputStream synsetIs = toResource(SYNSET_URI).getInputStream();
InputStream synsetToHumanIs = toResource(SYNSET_TO_HUMAN_URI).getInputStream()) {
List<String> synsetList = Arrays.asList(StreamUtils.copyToString(synsetIs, utf8).split("\n"))
.stream()
.map(l -> l.trim())
.collect(Collectors.toList());
Assert.notNull(synsetList, "Failed to initialize the labels list");
Assert.isTrue(synsetList.size() == 1000,
"Labels list is expected to be of " + "size 1000 but was:" + synsetList.size());
Map<String, String> synsetToHuman = Arrays
.asList(StreamUtils.copyToString(synsetToHumanIs, utf8).split("\n"))
.stream()
.map(s2h -> s2h.split("\t"))
.collect(Collectors.toMap(s -> s[0], s -> s[1]));
Assert.notNull(synsetToHuman, "Failed to initialize the synsetToHuman");
Assert.isTrue(synsetToHuman.size() == 21842,
"synsetToHuman is expected to be of " + "size 21842 but was:" + synsetToHuman.size());
List<String> l = synsetList.stream().map(id -> synsetToHuman.get(id)).collect(Collectors.toList());
List<String> ll = new ArrayList<>();
ll.add("dummy");
ll.addAll(l);
System.out.println(ll.get(389));
FileUtils.writeLines(new File("labels.txt"), ll);
}
catch (IOException e) {
throw new RuntimeException("Failed to initialize the Vocabulary", e);
}
}
public static Resource toResource(String uri) {
return new DefaultResourceLoader().getResource(uri);
}
}

Binary file not shown (image, 17 KiB).

View File

@@ -1,92 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
/**
* @author Christian Tzolov
*/
public final class ImageRecognitionExample {
private ImageRecognitionExample() {
}
public static void main(String[] args) throws IOException {
// You can use file:, http: or classpath: to provide the path to the input image.
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/giant_panda_in_beijing_zoo_1.jpg");
// MobileNetV2 models
// https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet#pretrained-models
String mobilenet_v2_modelUri = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz#mobilenet_v2_1.4_224_frozen.pb";
// String mobilenet_v2_modelUri =
// "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz#mobilenet_v2_0.35_96_frozen.pb";
try (ImageRecognition imageRecognition = ImageRecognition.mobileNetV2(mobilenet_v2_modelUri, 224, 5, true)) {
List<RecognitionResponse> recognizedObjects = ImageRecognition
.toRecognitionResponse(imageRecognition.recognizeTopK(inputImage));
// Draw the predicted labels on top of the input image.
byte[] augmentedImage = new ImageRecognitionAugmenter().apply(inputImage, recognizedObjects);
IOUtils.write(augmentedImage,
new FileOutputStream("./image-recognition/target/image-augmented-mobilnetV2.jpg"));
String jsonRecognizedObjects = new JsonMapperFunction().apply(recognizedObjects);
System.out.println("mobilnetV2 result:" + jsonRecognizedObjects);
}
String mobilenet_v1_modelUri = "https://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz#mobilenet_v1_1.0_224_frozen.pb";
try (ImageRecognition recognitionService = ImageRecognition.mobileNetV1(mobilenet_v1_modelUri, 224, 5, true)) {
List<RecognitionResponse> recognizedObjects = ImageRecognition
.toRecognitionResponse(recognitionService.recognizeTopK(inputImage));
// Draw the predicted labels on top of the input image.
byte[] augmentedImage = new ImageRecognitionAugmenter().apply(inputImage, recognizedObjects);
IOUtils.write(augmentedImage,
new FileOutputStream("./image-recognition/target/image-augmented-mobilnetV1.jpg"));
String jsonRecognizedObjects = new JsonMapperFunction().apply(recognizedObjects);
System.out.println("mobilnetV1 result:" + jsonRecognizedObjects);
}
String inception_modelUri = "https://storage.googleapis.com/scdf-tensorflow-models/image-recognition/tensorflow_inception_graph.pb";
try (ImageRecognition recognitionService = ImageRecognition.inception(inception_modelUri, 224, 5, true)) {
List<RecognitionResponse> recognizedObjects = ImageRecognition
.toRecognitionResponse(recognitionService.recognizeTopK(inputImage));
// Draw the predicted labels on top of the input image.
byte[] augmentedImage = new ImageRecognitionAugmenter().apply(inputImage, recognizedObjects);
IOUtils.write(augmentedImage,
new FileOutputStream("./image-recognition/target/image-augmented-inception.jpg"));
String jsonRecognizedObjects = new JsonMapperFunction().apply(recognizedObjects);
System.out.println("inception result:" + jsonRecognizedObjects);
}
}
}

View File

@@ -1,80 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
import java.io.FileOutputStream;
import java.io.IOException;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
/**
* @author Christian Tzolov
*/
public final class ImageRecognitionExample2 {
private ImageRecognitionExample2() {
}
public static void main(String[] args) throws IOException {
ImageRecognitionAugmenter augmenter = new ImageRecognitionAugmenter();
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/giant_panda_in_beijing_zoo_1.jpg");
ImageRecognition inceptions = ImageRecognition.inception(
"https://storage.googleapis.com/scdf-tensorflow-models/image-recognition/tensorflow_inception_graph.pb",
224, 10, true);
System.out.println(inceptions.recognizeMax(inputImage));
System.out.println(inceptions.recognizeTopK(inputImage));
System.out.println(ImageRecognition.toRecognitionResponse(inceptions.recognizeTopK(inputImage)));
IOUtils.write(
augmenter.apply(inputImage,
ImageRecognition.toRecognitionResponse(inceptions.recognizeTopK(inputImage))),
new FileOutputStream(
"./functions/function/image-recognition-function/target/image-augmented-inceptions.jpg"));
inceptions.close();
ImageRecognition mobileNetV2 = ImageRecognition.mobileNetV2(
"https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz#mobilenet_v2_1.4_224_frozen.pb",
224, 10, true);
System.out.println(mobileNetV2.recognizeMax(inputImage));
System.out.println(mobileNetV2.recognizeTopK(inputImage));
IOUtils.write(
augmenter.apply(inputImage,
ImageRecognition.toRecognitionResponse(mobileNetV2.recognizeTopK(inputImage))),
new FileOutputStream(
"./functions/function/image-recognition-function/target/image-augmented-mobilnetV2.jpg"));
mobileNetV2.close();
ImageRecognition mobileNetV1 = ImageRecognition.mobileNetV1(
"https://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz#mobilenet_v1_1.0_224_frozen.pb",
224, 10, true);
System.out.println(mobileNetV1.recognizeMax(inputImage));
System.out.println(mobileNetV1.recognizeTopK(inputImage));
IOUtils.write(
augmenter.apply(inputImage,
ImageRecognition.toRecognitionResponse(mobileNetV1.recognizeTopK(inputImage))),
new FileOutputStream(
"./functions/function/image-recognition-function/target/image-augmented-mobilnetV1.jpg"));
mobileNetV1.close();
}
}

View File

@@ -1,65 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.image.recognition;
import java.util.Map;
import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
/**
* @author Christian Tzolov
*/
public final class SavedModelTest {
private SavedModelTest() {
}
/**
* https://medium.com/@jsflo.dev/saving-and-loading-a-tensorflow-model-using-the-savedmodel-api-17645576527
*
* https://www.tensorflow.org/alpha/guide/saved_model
*
*/
public static void main(String[] args) throws InvalidProtocolBufferException {
SavedModelBundle savedModelBundle = SavedModelBundle
.load("/Users/ctzolov/Downloads/ssd_mobilenet_v1_coco_2017_11_17/saved_model", "serve");
// SavedModelBundle.load("/Users/ctzolov/Downloads/aiy_vision_classifier_plants_V1_1/",
// "serve");
// SavedModelBundle savedModelBundle =
// SavedModelBundle.load("/Users/ctzolov/Downloads/mnasnet-a1/saved_model",
// "serve");
MetaGraphDef meta = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef());
Map<String, SignatureDef> signatures = meta.getSignatureDefMap();
System.out.println(signatures);
savedModelBundle.session();
// Iterator<Operation> itr = savedModelBundle.graph().operations();
//
// while (itr.hasNext()) {
// System.out.println("Operation: " + itr.next());
// }
}
}

Binary file not shown.


View File

@@ -1,189 +0,0 @@
:images-asciidoc: https://raw.githubusercontent.com/spring-cloud/stream-applications/master/functions/function/object-detection-function/src/main/resources/images/
# Object Detection Function
Java model inference library for the https://github.com/tensorflow/models/blob/master/research/object_detection/README.md[TensorFlow Object Detection API]. Allows real-time localization and identification of multiple objects in a single image or a batch of images. Works with all https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md[pre-trained zoo models] and https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data[object labels].
[cols="1,2", frame=none, grid=none]
|===
| image:{images-asciidoc}/object_detection_1.jpg[alt=Object Detection 1, width=100%]
|The https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/main/java/org/springframework/cloud/fn/object/detection/ObjectDetectionService.java[ObjectDetectionService]
takes an image or a batch of images and outputs a list of predicted object bounding boxes
represented by https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/main/java/org/springframework/cloud/fn/object/detection/domain/ObjectDetection.java[ObjectDetection].
For the models supporting https://github.com/tensorflow/models/tree/master/research/object_detection#february-9-2018[Instance Segmentation],
the `ObjectDetectionService` can predict the instance segmentation `masks` in addition to object bounding boxes.
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/deprecated/JsonMapperFunction.java[JsonMapperFunction]
converts the `List<ObjectDetection>` into JSON objects, and the
https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/main/java/org/springframework/cloud/fn/object/detection/ObjectDetectionImageAugmenter.java[ObjectDetectionImageAugmenter]
augments the input image with the detected bounding boxes and segmentation masks.
|===
## Usage
Add the `object-detection` dependency to the pom (use the latest version available):
[source,xml]
----
<dependency>
<groupId>org.springframework.cloud.fn</groupId>
<artifactId>object-detection-function</artifactId>
<version>${revision}</version>
</dependency>
----
#### Example 1: Object Detection
The https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/test/java/org/springframework/cloud/fn/object/detection/examples/ExampleObjectDetection.java[ExampleObjectDetection.java]
sample demonstrates how to use the `ObjectDetectionService` for detecting objects in input images. It also shows how to
convert the result into JSON format and augment the input image with the detected object bounding boxes.
[source,java,linenums]
----
ObjectDetectionService detectionService = new ObjectDetectionService(
"https://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz#frozen_inference_graph.pb", //<1>
"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt", //<2>
0.4f, //<3>
false, //<4>
true); //<5>
byte[] image = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg"); //<6>
List<ObjectDetection> detectedObjects = detectionService.detect(image); //<7>
----
<1> Downloads and loads a pre-trained `frozen_inference_graph.pb` model directly from the `faster_rcnn_nas_coco.tar.gz` archive in the
TensorFlow model zoo. Note that the first run downloads a few hundred MB. Subsequent runs use the
cached copy (5) instead.
<2> Object category labels (i.e. names) for the model.
<3> Confidence threshold: only objects with an estimate above the threshold are returned.
<4> Indicates that this is not a `mask` (i.e. instance segmentation) model.
<5> Cache the model on the local file system.
<6> Load the input image to evaluate.
<7> Detect the objects in the image and represent the result as a list of `ObjectDetection` instances.
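The confidence threshold from callout 3 can be pictured as a plain filter over the raw detections. The following is a minimal, self-contained sketch rather than the library's own code: the `Detection` class and the `filterByConfidence` method are hypothetical names introduced for illustration, and the sketch assumes that a detection whose estimate equals the threshold is kept (the `ObjectDetectionOutputConverter` compares with `>=`).

```java
import java.util.ArrayList;
import java.util.List;

final class ConfidenceFilterSketch {

    /** Hypothetical stand-in for a detection: a label plus a confidence estimate. */
    static final class Detection {
        final String name;
        final float estimate;

        Detection(String name, float estimate) {
            this.name = name;
            this.estimate = estimate;
        }
    }

    /** Keeps only the detections whose estimate is at or above the threshold. */
    static List<Detection> filterByConfidence(List<Detection> all, float threshold) {
        List<Detection> kept = new ArrayList<>();
        for (Detection d : all) {
            if (d.estimate >= threshold) {
                kept.add(d);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Detection> all = new ArrayList<>();
        all.add(new Detection("person", 0.998f));
        all.add(new Detection("kite", 0.35f)); // below 0.4f, so it is dropped
        for (Detection d : filterByConfidence(all, 0.4f)) {
            System.out.println(d.name + " " + d.estimate);
        }
    }
}
```

A lower threshold returns more candidate objects at the cost of more false positives, so the value is usually tuned per model.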
Next, you can convert the result to JSON format.
[source,java,linenums]
----
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
System.out.println(jsonObjectDetections);
----
.Sample Object Detection JSON representation
[source,json]
----
[{"name":"person","estimate":0.998,"x1":0.160,"y1":0.774,"x2":0.201,"y2":0.946,"cid":1},
{"name":"kite","estimate":0.998,"x1":0.437,"y1":0.089,"x2":0.495,"y2":0.169,"cid":38},
{"name":"person","estimate":0.997,"x1":0.084,"y1":0.681,"x2":0.121,"y2":0.848,"cid":1},
{"name":"kite","estimate":0.988,"x1":0.206,"y1":0.263,"x2":0.225,"y2":0.314,"cid":38}]
----
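The `x1`, `y1`, `x2` and `y2` values in the sample above are fractions of the image width and height rather than pixels (the `ObjectDetectionImageAugmenter` multiplies them by the image dimensions before drawing). As a minimal sketch, assuming a hypothetical `toPixels` helper that is not part of the library, the conversion could look like this:

```java
import java.util.Arrays;

final class PixelBoxSketch {

    /** Scales normalized [0, 1] box corners to pixels; returns {x1, y1, x2, y2}. */
    static int[] toPixels(float x1, float y1, float x2, float y2, int imageWidth, int imageHeight) {
        return new int[] {
            (int) (x1 * (float) imageWidth),
            (int) (y1 * (float) imageHeight),
            (int) (x2 * (float) imageWidth),
            (int) (y2 * (float) imageHeight)
        };
    }

    public static void main(String[] args) {
        // First "person" entry from the sample JSON, on a hypothetical 1000x500 image.
        int[] box = toPixels(0.160f, 0.774f, 0.201f, 0.946f, 1000, 500);
        System.out.println(Arrays.toString(box));
    }
}
```

Keeping the coordinates normalized means the same detection result can be drawn on a resized copy of the image without re-running the model.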
Use the https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/main/java/org/springframework/cloud/fn/object/detection/ObjectDetectionImageAugmenter.java[ObjectDetectionImageAugmenter]
to draw the detected objects on top of the input image.
[source,java,linenums]
----
byte[] annotatedImage = new ObjectDetectionImageAugmenter().apply(image, detectedObjects); // <1>
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-augmented.jpg")); //<2>
----
<1> Augment the image with the detected object bounding boxes (Uses Java2D internally).
<2> Stores the augmented image as `object-detection-augmented.jpg` image file.
.Augmented object-detection-augmented.jpg file
image:{images-asciidoc}/object-detection-augmented.jpg[alt=Object Detection, width=60%]
TIP: Set the `ObjectDetectionImageAugmenter#agnosticColors` property to `true` to use a monochrome color scheme.
#### Example 2: Instance Segmentation
The https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/test/java/org/springframework/cloud/fn/object/detection/examples/ExampleInstanceSegmentation.java[ExampleInstanceSegmentation.java]
sample shows how to use the `ObjectDetectionService` for `Instance Segmentation`.
NOTE: It requires a trained model that supports `Masks` as well as setting the instance segmentation (e.g. `useMasks`) flag to `true`.
[source,java,linenums]
----
ObjectDetectionService detectionService = new ObjectDetectionService(
"https://download.tensorflow.org/models/object_detection/mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz#frozen_inference_graph.pb", // <1>
"https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt", // <2>
0.4f, // <3>
true, // <4>
true); // <5>
byte[] image = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
List<ObjectDetection> detectedObjects = detectionService.detect(image); // <6>
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects); // <7>
System.out.println(jsonObjectDetections);
byte[] annotatedImage = new ObjectDetectionImageAugmenter(true) // <8>
.apply(image, detectedObjects);
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-segmentation-augmented.jpg"));
----
<1> Uses one of the 4 MASK pre-trained models
<2> Object category labels (e.g. names) for the model
<3> Confidence threshold: only objects with an estimate above the threshold are returned.
<4> Use masks output: for the pre-trained models, instructs the service to use the extended fetch names, which also include the instance segmentation masks.
<5> Cache model: create a local copy of the model to speed up subsequent runs.
<6> Evaluate the model to predict the objects in the input image.
<7> Convert the detected objects into a JSON array. Note that with masks there is an additional field: `mask`.
<8> Draw the detected objects on top of the input image. The `true` constructor parameter enables drawing of the detected masks.
If `false`, only the bounding boxes are shown.
.Result augmented object-detection-segmentation-augmented.jpg file
image:{images-asciidoc}/object-detection-segmentation-augmented.jpg[alt=Object Detection Augmented, width=60%]
## Models
All pre-trained https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md[detection_model_zoo.md]
models are supported. The following URI notation can be used to download any of the models directly from the zoo.
----
http://<zoo model tar.gz url>#frozen_inference_graph.pb
----
The `frozen_inference_graph.pb` is the frozen model file name within the archive.
NOTE: For some models this name may differ. You have to download and open the archive to find the real name.
TIP: To speed up the bootstrap, consider extracting the `frozen_inference_graph.pb` and caching it
locally. Then you can use the `file://path-to-my-local-copy` URI scheme to access it.
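To make the `<archive url>#<model file name>` notation concrete, the sketch below splits such a URI into the archive location and the file name carried in the fragment. It only illustrates the notation with the standard `java.net.URI` class. The `ModelUriSketch` class and its methods are hypothetical, and how the library itself resolves and extracts the archive is not shown here.

```java
import java.net.URI;

final class ModelUriSketch {

    /** Returns the archive location, that is the URI text before the '#' separator. */
    static String archiveLocation(String modelUri) {
        int hash = modelUri.indexOf('#');
        return (hash < 0) ? modelUri : modelUri.substring(0, hash);
    }

    /** Returns the model file name inside the archive, or null when there is no fragment. */
    static String modelFileName(String modelUri) {
        return URI.create(modelUri).getFragment();
    }

    public static void main(String[] args) {
        String uri = "https://download.tensorflow.org/models/object_detection/"
                + "faster_rcnn_nas_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
        System.out.println("archive: " + archiveLocation(uri));
        System.out.println("model file: " + modelFileName(uri));
    }
}
```

A local copy referenced with the `file:` scheme has no fragment, because it points at the extracted `.pb` file directly.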
The following models can be used for `Instance Segmentation` as well:
[frame=none, grid=none]
|===
| https://download.tensorflow.org/models/object_detection/mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz[mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz]
| https://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz[mask_rcnn_inception_v2_coco_2018_01_28.tar.gz]
| https://download.tensorflow.org/models/object_detection/mask_rcnn_resnet101_atrous_coco_2018_01_28.tar.gz[mask_rcnn_resnet101_atrous_coco_2018_01_28.tar.gz]
| https://download.tensorflow.org/models/object_detection/mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz[mask_rcnn_resnet50_atrous_coco_2018_01_28.tar.gz]
|===
In addition to the model, the `ObjectDetectionService` requires a list of labels that correspond to the categories detectable by the selected model.
All labels files are available in the https://github.com/tensorflow/models/tree/master/research/object_detection/data[object_detection/data] folder.
NOTE: It is important to use the labels that correspond to the model being used! The table below shows this mapping.
.Relationship between trained model types and category labels
[%header, cols="1,2", frame=none, grid=none]
|===
| Model
| Labels
| https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models[coco]
| https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt[mscoco_label_map.pbtxt]
| https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#kitti-trained-models[kitti]
| https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/kitti_label_map.pbtxt[kitti_label_map.pbtxt]
| https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#open-images-trained-models[open-images]
| https://github.com/tensorflow/models/blob/master/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt[oid_bbox_trainable_label_map.pbtxt]
| https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#inaturalist-species-trained-models[inaturalist-species]
| https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/fgvc_2854_classes_label_map.pbtxt[fgvc_2854_classes_label_map.pbtxt]
| https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#ava-v21-trained-models[ava]
| https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/ava_label_map_v2.1.pbtxt[ava_label_map_v2.1.pbtxt]
|===
TIP: For performance reasons you may consider downloading the required label files to the local file system.
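The reason the labels must match the model is that the model only outputs numeric class ids, which are then used as indexes into the label list. The sketch below mirrors that lookup under stated assumptions: `LabelItem` is a hypothetical stand-in for one entry of a label map file, the sample values are illustrative, and the display name is preferred with the item name as a fallback.

```java
import java.util.ArrayList;
import java.util.List;

final class LabelLookupSketch {

    /** Hypothetical stand-in for one label map entry. */
    static final class LabelItem {
        final int id;
        final String name;
        final String displayName;

        LabelItem(int id, String name, String displayName) {
            this.id = id;
            this.name = name;
            this.displayName = displayName;
        }
    }

    /** Builds an array indexed by class id; ids without an entry stay null. */
    static String[] toIdIndexedArray(List<LabelItem> items) {
        int maxId = -1;
        for (LabelItem item : items) {
            maxId = Math.max(maxId, item.id);
        }
        String[] labels = new String[maxId + 1];
        for (LabelItem item : items) {
            boolean hasDisplayName = item.displayName != null && !item.displayName.trim().isEmpty();
            labels[item.id] = hasDisplayName ? item.displayName : item.name;
        }
        return labels;
    }

    public static void main(String[] args) {
        List<LabelItem> items = new ArrayList<>();
        items.add(new LabelItem(1, "placeholder-name-1", "person"));
        items.add(new LabelItem(38, "placeholder-name-38", "kite"));
        String[] labels = toIdIndexedArray(items);
        // A model trained on a different data set would emit ids that point at the wrong names.
        System.out.println(labels[1] + ", " + labels[38]);
    }
}
```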

View File

@@ -1,12 +0,0 @@
apply plugin: 'com.google.protobuf'
dependencies {
api project(':spring-tensorflow-common')
api protobufJava
}
protobuf {
protoc {
artifact = "com.google.protobuf:protoc:$protobufVersion"
}
}

View File

@@ -1,120 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.util.CollectionUtils;
/**
* Augments the input image with the detected object bounding boxes and categories. For
* mask models and withMask set to true it draws the instance segmentation image as well.
*
* @author Christian Tzolov
*/
public class ObjectDetectionImageAugmenter implements BiFunction<byte[], List<ObjectDetection>, byte[]> {
private static final Log logger = LogFactory.getLog(ObjectDetectionImageAugmenter.class);
/** Make checkstyle happy. **/
public static final String DEFAULT_IMAGE_FORMAT = "jpg";
private String imageFormat = DEFAULT_IMAGE_FORMAT;
private final boolean withMask;
private boolean agnosticColors = false;
public ObjectDetectionImageAugmenter() {
this(false);
}
public ObjectDetectionImageAugmenter(boolean withMask) {
this.withMask = withMask;
}
public boolean isAgnosticColors() {
return agnosticColors;
}
public void setAgnosticColors(boolean agnosticColors) {
this.agnosticColors = agnosticColors;
}
public String getImageFormat() {
return imageFormat;
}
public void setImageFormat(String imageFormat) {
this.imageFormat = imageFormat;
}
@Override
public byte[] apply(byte[] imageBytes, List<ObjectDetection> objectDetections) {
if (!CollectionUtils.isEmpty(objectDetections)) {
try {
BufferedImage bufferedImage = ImageIO.read(new ByteArrayInputStream(imageBytes));
for (ObjectDetection od : objectDetections) {
int y1 = (int) (od.getY1() * (float) bufferedImage.getHeight());
int x1 = (int) (od.getX1() * (float) bufferedImage.getWidth());
int y2 = (int) (od.getY2() * (float) bufferedImage.getHeight());
int x2 = (int) (od.getX2() * (float) bufferedImage.getWidth());
int cid = od.getCid();
String labelName = od.getName();
int probability = (int) (100 * od.getConfidence());
String title = labelName + ": " + probability + "%";
GraphicsUtils.drawBoundingBox(bufferedImage, cid, title, x1, y1, x2, y2, this.agnosticColors);
if (this.withMask && od.getMask() != null) {
float[][] mask = od.getMask();
if (mask != null) {
Color maskColor = this.agnosticColors ? null : GraphicsUtils.getClassColor(cid);
BufferedImage maskImage = GraphicsUtils.createMaskImage(mask, x2 - x1, y2 - y1, maskColor);
GraphicsUtils.overlayImages(bufferedImage, maskImage, x1, y1);
}
}
}
imageBytes = GraphicsUtils.toImageByteArray(bufferedImage, this.getImageFormat());
}
catch (IOException e) {
logger.error(e);
}
}
// Falls back to the original image bytes when there are no detections or the augmentation fails.
return imageBytes;
}
}

View File

@@ -1,79 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
/**
* Converts a byte array image into an input Tensor for the Object Detection API.
*
* @author Christian Tzolov
*/
public class ObjectDetectionInputAdapter implements Function<byte[], Map<String, Tensor<?>>>, AutoCloseable {
private static final Log logger = LogFactory.getLog(ObjectDetectionInputAdapter.class);
/** Make checkstyle happy. **/
public static final String RAW_IMAGE = "raw_image";
/** Make checkstyle happy. **/
public static final String NORMALIZED_IMAGE = "normalized_image";
/** Make checkstyle happy. **/
public static final long CHANNELS = 3;
private final GraphRunner imageLoaderGraph;
public ObjectDetectionInputAdapter() {
this.imageLoaderGraph = new GraphRunner(RAW_IMAGE, NORMALIZED_IMAGE).withGraphDefinition(tf -> {
Placeholder<String> rawImage = tf.withName(RAW_IMAGE).placeholder(String.class);
Operand<UInt8> decodedImage = tf.dtypes.cast(tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(CHANNELS)),
UInt8.class);
// Expand dimensions since the model expects images to have shape: [1, H, W,
// 3]
tf.withName(NORMALIZED_IMAGE).expandDims(decodedImage, tf.constant(0));
});
}
@Override
public Map<String, Tensor<?>> apply(byte[] inputImage) {
try (Tensor<?> inputTensor = Tensor.<Object>create(inputImage)) {
return this.imageLoaderGraph.apply(Collections.singletonMap(RAW_IMAGE, inputTensor));
}
}
@Override
public void close() {
if (this.imageLoaderGraph != null) {
this.imageLoaderGraph.close();
}
}
}

View File

@@ -1,99 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
/**
* Converts byte array images into an input Tensor for the Object Detection API. The
* computed image tensor uses the 'image_tensor' model placeholder.
*
* @author Christian Tzolov
*/
public class ObjectDetectionInputConverter implements Function<byte[][], Map<String, Tensor<?>>> {
private static final Log logger = LogFactory.getLog(ObjectDetectionInputConverter.class);
private static final long CHANNELS = 3;
/** Make checkstyle happy. **/
public static final String IMAGE_TENSOR_FEED_NAME = "image_tensor";
@Override
public Map<String, Tensor<?>> apply(byte[][] input) {
return Collections.singletonMap(IMAGE_TENSOR_FEED_NAME, makeImageTensor(input));
}
private static Tensor<UInt8> makeImageTensor(byte[][] imageBytesArray) {
try {
int batchSize = imageBytesArray.length;
ByteBuffer byteBuffer = null;
long[] shape = null;
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
byte[] imageBytes = imageBytesArray[batchIndex];
ByteArrayInputStream is = new ByteArrayInputStream(imageBytes);
BufferedImage img = ImageIO.read(is);
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
img = GraphicsUtils.toBufferedImageType(img, BufferedImage.TYPE_3BYTE_BGR);
}
if (byteBuffer == null) {
byteBuffer = ByteBuffer.allocate((int) (batchSize * img.getHeight() * img.getWidth() * CHANNELS));
shape = new long[] { batchSize, img.getHeight(), img.getWidth(), CHANNELS };
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read produces BGR-encoded images, while the model expects RGB.
bgrToRgb(data);
byteBuffer.put(data);
}
byteBuffer.flip();
return Tensor.create(UInt8.class, shape, byteBuffer);
}
catch (IOException e) {
throw new IllegalArgumentException("Incorrect image format", e);
}
}
private static void bgrToRgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
}

View File

@@ -1,198 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import com.google.protobuf.TextFormat;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Tensor;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.cloud.fn.object.detection.protos.StringIntLabelMapOuterClass;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
/**
* Converts the Tensorflow Object Detection result into {@link ObjectDetection} list. The
* pre-trained Object Detection models (http://bit.ly/2osxMAY) produce 3 tensor outputs:
* (1) detection_classes - containing the ids of detected objects, (2) detection_scores -
* confidence probabilities of the detected object and (3) detection_boxes - the object
* bounding boxes within the images.
*
* The MASK based models provide 2 additional tensors: (4) num_detections and (5)
* detection_masks.
*
* All output tensors are float arrays, having 1 as the first dimension and maxObjects
* as the second dimension, while boxesT has 4 as the third dimension (2 sets of (x,
* y) coordinates). This can be verified by looking at scoresT.shape() etc.
*
* The format of the detected class (i.e. label) names is defined by the
* 'string_int_label_map.proto'. The input list is available at:
* https://github.com/tensorflow/models/tree/master/research/object_detection/data
*
* @author Christian Tzolov
*/
public class ObjectDetectionOutputConverter implements Function<Map<String, Tensor<?>>, List<List<ObjectDetection>>> {
private static final Log logger = LogFactory.getLog(ObjectDetectionOutputConverter.class);
/** DETECTION_CLASSES. */
public static final String DETECTION_CLASSES = "detection_classes";
/** DETECTION_SCORES. */
public static final String DETECTION_SCORES = "detection_scores";
/** DETECTION_BOXES. */
public static final String DETECTION_BOXES = "detection_boxes";
/** DETECTION_MASKS. */
public static final String DETECTION_MASKS = "detection_masks";
/** NUM_DETECTIONS. */
public static final String NUM_DETECTIONS = "num_detections";
private final String[] labels;
private float confidence;
private List<String> modelFetch;
public ObjectDetectionOutputConverter(Resource labelsResource, float confidence, List<String> modelFetch) {
this.confidence = confidence;
this.modelFetch = modelFetch;
try {
this.labels = loadLabels(labelsResource);
Assert.notNull(this.labels, String.format("Failed to initialize object labels [%s].", labelsResource));
}
catch (Exception e) {
throw new RuntimeException(String.format("Failed to initialize object labels [%s].", labelsResource), e);
}
logger.info(String.format("Object labels [%s] loaded.", labelsResource));
}
/**
* Loads object labels in the string_int_label_map.proto.
* @param labelsResource location of the labels as a resource
* @return String[] of labels
*/
private static String[] loadLabels(Resource labelsResource) throws Exception {
try (InputStream is = labelsResource.getInputStream()) {
String text = StreamUtils.copyToString(is, Charset.forName("UTF-8"));
StringIntLabelMapOuterClass.StringIntLabelMap.Builder builder = StringIntLabelMapOuterClass.StringIntLabelMap
.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMapOuterClass.StringIntLabelMap proto = builder.build();
int maxLabelId = proto.getItemList()
.stream()
.map(StringIntLabelMapOuterClass.StringIntLabelMapItem::getId)
.max(Comparator.comparing(i -> i))
.orElse(-1);
String[] labelIdToNameMap = new String[maxLabelId + 1];
for (StringIntLabelMapOuterClass.StringIntLabelMapItem item : proto.getItemList()) {
if (StringUtils.hasText(item.getDisplayName())) {
labelIdToNameMap[item.getId()] = item.getDisplayName();
}
else {
// Common practice is to set the name to a MID or Synsets Id. Synset
// is a set of synonyms that
// share a common meaning: https://en.wikipedia.org/wiki/WordNet
labelIdToNameMap[item.getId()] = item.getName();
}
}
return labelIdToNameMap;
}
}
@Override
public List<List<ObjectDetection>> apply(Map<String, Tensor<?>> tensorMap) {
try (Tensor<Float> scoresTensor = tensorMap.get(DETECTION_SCORES).expect(Float.class);
Tensor<Float> classesTensor = tensorMap.get(DETECTION_CLASSES).expect(Float.class);
Tensor<Float> boxesTensor = tensorMap.get(DETECTION_BOXES).expect(Float.class)) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y)
// coordinates).
// This can be verified by looking at scoresT.shape() etc.
int batchSize = (int) scoresTensor.shape()[0];
int maxObjects = (int) scoresTensor.shape()[1];
float[][] scores = scoresTensor.copyTo(new float[batchSize][maxObjects]);
float[][] classes = classesTensor.copyTo(new float[batchSize][maxObjects]);
float[][][] boxes = boxesTensor.copyTo(new float[batchSize][maxObjects][4]);
List<List<ObjectDetection>> batchObjectDetections = new ArrayList<>();
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
List<ObjectDetection> objectDetections = new ArrayList<>();
// Collect only the objects whose scores are at or above the configured
// confidence threshold.
for (int i = 0; i < scores[batchIndex].length; ++i) {
if (scores[batchIndex][i] >= confidence) {
String category = labels[(int) classes[batchIndex][i]];
float score = scores[batchIndex][i];
ObjectDetection od = new ObjectDetection();
od.setName(category);
od.setConfidence(score);
od.setX1(boxes[batchIndex][i][1]);
od.setY1(boxes[batchIndex][i][0]);
od.setX2(boxes[batchIndex][i][3]);
od.setY2(boxes[batchIndex][i][2]);
od.setCid((int) classes[batchIndex][i]);
// Mask allows image-segmentation
if (modelFetch.contains(DETECTION_MASKS) && modelFetch.contains(NUM_DETECTIONS)) {
Tensor<Float> masksTensor = tensorMap.get(DETECTION_MASKS).expect(Float.class);
Tensor<Float> numDetections = tensorMap.get(NUM_DETECTIONS).expect(Float.class);
float nd = numDetections.copyTo(new float[batchSize])[0];
if (masksTensor != null) {
long[] shape = masksTensor.shape();
float[][][][] masks = masksTensor
.copyTo(new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]);
od.setMask(masks[batchIndex][i]);
logger.debug(String.format("Num detections: %s, Masks: %s", nd, masks));
}
}
objectDetections.add(od);
}
}
batchObjectDetections.add(objectDetections);
}
return batchObjectDetections;
}
}
}

View File

@@ -1,167 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.TensorFlowService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.util.StreamUtils;
/**
* Convenience class that leverages the {@link ObjectDetectionInputConverter},
* {@link ObjectDetectionOutputConverter} and {@link TensorFlowService} in combination
* with the Tensorflow Object Detection API
* (https://github.com/tensorflow/models/tree/master/research/object_detection) models for
* detecting objects in input images.
*
* All pre-trained models
* (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)
* and labels are supported.
*
* You can download pre-trained models directly from the zoo:
* https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
* Just use the URI notation: (zoo model tar.gz url)#(name of the frozen model file).
* To speed up bootstrapping, consider downloading the models locally
* and using the file:/"path to my model" URI instead.
*
* The object category labels for the pre-trained models are available at:
* https://github.com/tensorflow/models/tree/master/research/object_detection/data
* Use the labels applicable to the model. Also, for performance reasons, consider
* downloading the labels and loading them from file: instead.
*
* @author Christian Tzolov
*/
public class ObjectDetectionService {
/** Default list of fetch names for Box models. */
public static List<String> FETCH_NAMES = Arrays.asList(ObjectDetectionOutputConverter.DETECTION_SCORES,
ObjectDetectionOutputConverter.DETECTION_CLASSES, ObjectDetectionOutputConverter.DETECTION_BOXES,
ObjectDetectionOutputConverter.NUM_DETECTIONS);
/** Default list of fetch names for mask supporting models. */
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(ObjectDetectionOutputConverter.DETECTION_SCORES,
ObjectDetectionOutputConverter.DETECTION_CLASSES, ObjectDetectionOutputConverter.DETECTION_BOXES,
ObjectDetectionOutputConverter.DETECTION_MASKS, ObjectDetectionOutputConverter.NUM_DETECTIONS);
private final ObjectDetectionInputConverter inputConverter;
private final ObjectDetectionOutputConverter outputConverter;
private final TensorFlowService tensorFlowService;
public ObjectDetectionService() {
this("https://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz#frozen_inference_graph.pb",
"https://storage.googleapis.com/scdf-tensorflow-models/object-detection/mscoco_label_map.pbtxt", 0.4f,
false, true);
}
/**
* Convenience constructor that initializes all necessary internal components.
* @param modelUri URI of the pre-trained, frozen Tensorflow model.
* @param labelsUri URI of the pre-trained category labels.
* @param confidence Confidence threshold. Only objects detected with confidence at or
* above this threshold will be returned.
* @param withMasks If a Mask model is selected then you can use this flag to extract
* the instance segmentation masks as well.
* @param cacheModel If true, the model is cached on the local file system to improve
* bootstrap performance on consecutive runs.
*/
public ObjectDetectionService(String modelUri, String labelsUri, float confidence, boolean withMasks,
boolean cacheModel) {
this.inputConverter = new ObjectDetectionInputConverter();
List<String> fetchNames = withMasks ? FETCH_NAMES_WITH_MASKS : FETCH_NAMES;
this.outputConverter = new ObjectDetectionOutputConverter(new DefaultResourceLoader().getResource(labelsUri),
confidence, fetchNames);
this.tensorFlowService = new TensorFlowService(new DefaultResourceLoader().getResource(modelUri), fetchNames,
cacheModel);
}
/**
* Generic constructor that allows the converters to be pre-configured before being
* passed to the service.
* @param inputConverter Converter from byte array to object detection input image
* tensor
* @param outputConverter Converts the object detection output tensors into
* {@link ObjectDetection} list
* @param tensorFlowService Java tensorflow runner instance
*/
public ObjectDetectionService(ObjectDetectionInputConverter inputConverter,
ObjectDetectionOutputConverter outputConverter, TensorFlowService tensorFlowService) {
this.inputConverter = inputConverter;
this.outputConverter = outputConverter;
this.tensorFlowService = tensorFlowService;
}
/**
* Detects objects in a single input image identified by its URI.
* @param imageUri input image's URI
* @return Returns a list of {@link ObjectDetection} domain objects representing
* detected objects
*/
public List<ObjectDetection> detect(String imageUri) {
try (InputStream is = new DefaultResourceLoader().getResource(imageUri).getInputStream()) {
return this.detect(StreamUtils.copyToByteArray(is));
}
catch (IOException e) {
e.printStackTrace();
throw new IllegalStateException("Failed to detect the image:" + imageUri, e);
}
}
/**
* Detects objects in a single {@link BufferedImage}.
* @param image Input image to detect objects from.
* @param format Image format (e.g. jpg, png ...) to use when converting the buffer
* into byte array.
* @return Returns a list of {@link ObjectDetection} domain objects representing
* detected objects in the input image
*/
public List<ObjectDetection> detect(BufferedImage image, String format) {
return this.detect(GraphicsUtils.toImageByteArray(image, format));
}
/**
* Detects objects from a single input image encoded as byte array.
* @param image Input image encoded as byte array
* @return Returns a list of {@link ObjectDetection} domain objects representing
* detected objects in the input image
*/
public List<ObjectDetection> detect(byte[] image) {
return this.inputConverter.andThen(this.tensorFlowService)
.andThen(this.outputConverter)
.apply(new byte[][] { image })
.get(0);
}
/**
* Detects objects in a batch of input images encoded as byte arrays.
* @param batchedImages Batch of input images encoded as byte arrays. First dimension
* is the batch size and second the image bytes.
* @return Returns list of lists. For every input image in the batch a list of
* {@link ObjectDetection} domain objects representing detected objects in the input
* image.
*/
public List<List<ObjectDetection>> detect(byte[][] batchedImages) {
return this.inputConverter.andThen(this.tensorFlowService).andThen(this.outputConverter).apply(batchedImages);
}
}
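
The Javadoc above describes model locations in the form (archive url)#(name of the frozen model file). How the removed module resolved that notation is not shown in this diff, so the following is only a sketch of how such a string could be split using `java.net.URI`. `ModelUriParts` and `split` are hypothetical names.

```java
import java.net.URI;

final class ModelUriParts {

	// Returns { archive location, entry name }. The entry name is null when the
	// URI carries no '#' fragment, i.e. it points at the model file directly.
	static String[] split(String modelUri) {
		String fragment = URI.create(modelUri).getFragment();
		if (fragment == null) {
			return new String[] { modelUri, null };
		}
		String archive = modelUri.substring(0, modelUri.indexOf('#'));
		return new String[] { archive, fragment };
	}

	public static void main(String[] args) {
		String[] parts = split("https://download.tensorflow.org/models/object_detection/"
				+ "ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz#frozen_inference_graph.pb");
		System.out.println("archive: " + parts[0]);
		System.out.println("entry:   " + parts[1]);
	}
}
```

A plain `file:` URI without a fragment falls through the first branch, which matches the advice to point at a locally downloaded model.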


@@ -1,115 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
/**
* @author Christian Tzolov
*/
public class ObjectDetectionService2 implements AutoCloseable {
/** Default Box models fetch names. */
public static List<String> FETCH_NAMES = Arrays.asList(ObjectDetectionOutputConverter.DETECTION_SCORES,
ObjectDetectionOutputConverter.DETECTION_CLASSES, ObjectDetectionOutputConverter.DETECTION_BOXES,
ObjectDetectionOutputConverter.NUM_DETECTIONS);
/** Default fetch names for mask-supporting models. */
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(ObjectDetectionOutputConverter.DETECTION_SCORES,
ObjectDetectionOutputConverter.DETECTION_CLASSES, ObjectDetectionOutputConverter.DETECTION_BOXES,
ObjectDetectionOutputConverter.DETECTION_MASKS, ObjectDetectionOutputConverter.NUM_DETECTIONS);
private final GraphRunner imageNormalization;
private final GraphRunner objectDetection;
private final ObjectDetectionOutputConverter outputConverter;
public ObjectDetectionService2(String modelUri, ObjectDetectionOutputConverter outputConverter) {
this.imageNormalization = new GraphRunner("raw_image", "normalized_image").withGraphDefinition(tf -> {
Placeholder<String> rawImage = tf.withName("raw_image").placeholder(String.class);
Operand<UInt8> decodedImage = tf.dtypes.cast(tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(3L)),
UInt8.class);
// Expand dimensions since the model expects images to have shape: [1, H, W, 3]
tf.withName("normalized_image").expandDims(decodedImage, tf.constant(0));
});
this.objectDetection = new GraphRunner(Arrays.asList("image_tensor"), FETCH_NAMES)
.withGraphDefinition(new ProtoBufGraphDefinition(new DefaultResourceLoader().getResource(modelUri), true));
this.outputConverter = outputConverter;
}
public List<ObjectDetection> detect(byte[] image) {
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memorize = new GraphRunnerMemory()) {
List<List<ObjectDetection>> out = this.imageNormalization.andThen(memorize)
.andThen(this.objectDetection)
.andThen(memorize)
.andThen(outputConverter)
.apply(Collections.singletonMap("raw_image", inputTensor));
return out.get(0);
}
}
@Override
public void close() {
this.imageNormalization.close();
this.objectDetection.close();
// this.outputConverter.close();
}
public static void main(String[] args) throws IOException {
String modelUri = "https://dl.bintray.com/big-data/generic/ssdlite_mobilenet_v2_coco_2018_05_09_frozen_inference_graph.pb";
String labelUri = "https://dl.bintray.com/big-data/generic/mscoco_label_map.pbtxt";
ObjectDetectionOutputConverter outputAdapter = new ObjectDetectionOutputConverter(
new DefaultResourceLoader().getResource(labelUri), 0.4f, FETCH_NAMES);
// byte[] inputImage =
// GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/wild-animals-15.jpg");
try (ObjectDetectionService2 objectDetectionService2 = new ObjectDetectionService2(modelUri, outputAdapter)) {
List<ObjectDetection> detections = objectDetectionService2.detect(inputImage);
System.out.println(detections);
}
}
}


@@ -1,116 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.domain;
import java.util.Arrays;
import com.fasterxml.jackson.annotation.JsonInclude;
/**
* @author Christian Tzolov
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ObjectDetection {
private String name;
private float confidence;
private float x1;
private float y1;
private float x2;
private float y2;
private float[][] mask;
private int cid;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public float getConfidence() {
return confidence;
}
public void setConfidence(float confidence) {
this.confidence = confidence;
}
public float getX1() {
return x1;
}
public void setX1(float x1) {
this.x1 = x1;
}
public float getY1() {
return y1;
}
public void setY1(float y1) {
this.y1 = y1;
}
public float getX2() {
return x2;
}
public void setX2(float x2) {
this.x2 = x2;
}
public float getY2() {
return y2;
}
public void setY2(float y2) {
this.y2 = y2;
}
public int getCid() {
return cid;
}
public void setCid(int cid) {
this.cid = cid;
}
public float[][] getMask() {
return mask;
}
public void setMask(float[][] mask) {
this.mask = mask;
}
@Override
public String toString() {
return "ObjectDetection{" + "name='" + name + '\'' + ", confidence=" + confidence + ", x1=" + x1 + ", y1=" + y1
+ ", x2=" + x2 + ", y2=" + y2 + ", mask=" + Arrays.toString(mask) + ", cid=" + cid + '}';
}
}


@@ -1,301 +0,0 @@
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "feature.proto";
option cc_enable_arenas = true;
option java_outer_classname = "ExampleProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.example";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
package tensorflow;
// An Example is a mostly-normalized data format for storing data for
// training and inference. It contains a key-value store (features); where
// each key (string) maps to a Feature message (which is oneof packed BytesList,
// FloatList, or Int64List). This flexible and compact format allows the
// storage of large amounts of typed data, but requires that the data shape
// and use be determined by the configuration files and parsers that are used to
// read and write this format. That is, the Example is mostly *not* a
// self-describing format. In TensorFlow, Examples are read in row-major
// format, so any configuration that describes data with rank-2 or above
// should keep this in mind. For example, to store an M x N matrix of Bytes,
// the BytesList must contain M*N bytes, with M rows of N contiguous values
// each. That is, the BytesList value must store the matrix as:
// .... row 0 .... .... row 1 .... // ........... // ... row M-1 ....
//
// An Example for a movie recommendation application:
// features {
// feature {
// key: "age"
// value { float_list {
// value: 29.0
// }}
// }
// feature {
// key: "movie"
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }}
// }
// feature {
// key: "movie_ratings"
// value { float_list {
// value: 9.0
// value: 9.7
// }}
// }
// feature {
// key: "suggestion"
// value { bytes_list {
// value: "Inception"
// }}
// }
// # Note that this feature exists to be used as a label in training.
// # E.g., if training a logistic regression model to predict purchase
// # probability in our learning tool we would set the label feature to
// # "suggestion_purchased".
// feature {
// key: "suggestion_purchased"
// value { float_list {
// value: 1.0
// }}
// }
// # Similar to "suggestion_purchased" above this feature exists to be used
// # as a label in training.
// # E.g., if training a linear regression model to predict purchase
// # price in our learning tool we would set the label feature to
// # "purchase_price".
// feature {
// key: "purchase_price"
// value { float_list {
// value: 9.99
// }}
// }
// }
//
// A conformant Example data set obeys the following conventions:
// - If a Feature K exists in one example with data type T, it must be of
// type T in all other examples when present. It may be omitted.
// - The number of instances of Feature K list data may vary across examples,
// depending on the requirements of the model.
// - If a Feature K doesn't exist in an example, a K-specific default will be
// used, if configured.
// - If a Feature K exists in an example but contains no items, the intent
// is considered to be an empty tensor and no default will be used.
message Example {
Features features = 1;
};
// A SequenceExample is an Example representing one or more sequences, and
// some context. The context contains features which apply to the entire
// example. The feature_lists contain a key, value map where each key is
// associated with a repeated set of Features (a FeatureList).
// A FeatureList thus represents the values of a feature identified by its key
// over time / frames.
//
// Below is a SequenceExample for a movie recommendation application recording a
// sequence of ratings by a user. The time-independent features ("locale",
// "age", "favorites") describing the user are part of the context. The sequence
// of movies the user rated are part of the feature_lists. For each movie in the
// sequence we have information on its name and actors and the user's rating.
// This information is recorded in three separate feature_list(s).
// In the example below there are only two movies. All three feature_list(s),
// namely "movie_ratings", "movie_names", and "actors" have a feature value for
// both movies. Note, that "actors" is itself a bytes_list with multiple
// strings per movie.
//
// context: {
// feature: {
// key : "locale"
// value: {
// bytes_list: {
// value: [ "pt_BR" ]
// }
// }
// }
// feature: {
// key : "age"
// value: {
// float_list: {
// value: [ 19.0 ]
// }
// }
// }
// feature: {
// key : "favorites"
// value: {
// bytes_list: {
// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ]
// }
// }
// }
// }
// feature_lists: {
// feature_list: {
// key : "movie_ratings"
// value: {
// feature: {
// float_list: {
// value: [ 4.5 ]
// }
// }
// feature: {
// float_list: {
// value: [ 5.0 ]
// }
// }
// }
// }
// feature_list: {
// key : "movie_names"
// value: {
// feature: {
// bytes_list: {
// value: [ "The Shawshank Redemption" ]
// }
// }
// feature: {
// bytes_list: {
// value: [ "Fight Club" ]
// }
// }
// }
// }
// feature_list: {
// key : "actors"
// value: {
// feature: {
// bytes_list: {
// value: [ "Tim Robbins", "Morgan Freeman" ]
// }
// }
// feature: {
// bytes_list: {
// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ]
// }
// }
// }
// }
// }
//
// A conformant SequenceExample data set obeys the following conventions:
//
// Context:
// - All conformant context features K must obey the same conventions as
// a conformant Example's features (see above).
// Feature lists:
// - A FeatureList L may be missing in an example; it is up to the
// parser configuration to determine if this is allowed or considered
// an empty list (zero length).
// - If a FeatureList L exists, it may be empty (zero length).
// - If a FeatureList L is non-empty, all features within the FeatureList
// must have the same data type T. Even across SequenceExamples, the type T
// of the FeatureList identified by the same key must be the same. An entry
// without any values may serve as an empty feature.
// - If a FeatureList L is non-empty, it is up to the parser configuration
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
// - For sequence modeling, e.g.:
// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
// https://github.com/tensorflow/nmt
// the feature lists represent a sequence of frames.
// In this scenario, all FeatureLists in a SequenceExample have the same
// number of Feature messages, so that the ith element in each FeatureList
// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
//
// Non-conformant FeatureLists (mismatched types):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { int64_list: { value: [ 5 ] } } }
// } }
//
// Conditionally conformant FeatureLists, the parser configuration determines
// if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0, 6.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } }
// feature: { float_list: { value: [ 2.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { }
// } }
//
// Conditionally conformant pair of SequenceExample, the parser configuration
// determines if the second feature_lists is consistent (zero-length) or
// invalid (missing "movie_ratings"):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { }
//
// Non-conformant pair of SequenceExample (mismatched types)
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { int64_list: { value: [ 4 ] } }
// feature: { int64_list: { value: [ 5 ] } }
// feature: { int64_list: { value: [ 2 ] } } }
// } }
//
// Conditionally conformant pair of SequenceExample; the parser configuration
// determines if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.0 ] } }
// feature: { float_list: { value: [ 5.0, 3.0 ] } }
// } }
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};
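
The comments in this file state that an M x N matrix of bytes must be stored as M*N bytes in row-major order, with each row's N values contiguous. A small sketch of that layout, assuming a rectangular matrix. `RowMajor`, `flatten` and `get` are hypothetical helper names, not part of TensorFlow or the removed module.

```java
import java.util.Arrays;

final class RowMajor {

	// Flattens an M x N matrix into M*N bytes: row 0 first, then row 1, and so on.
	// Assumption: every row has the same length as row 0.
	static byte[] flatten(byte[][] matrix) {
		int rows = matrix.length;
		int cols = (rows == 0) ? 0 : matrix[0].length;
		byte[] flat = new byte[rows * cols];
		for (int r = 0; r < rows; r++) {
			System.arraycopy(matrix[r], 0, flat, r * cols, cols);
		}
		return flat;
	}

	// Reads element (row, col) back from the flattened form.
	static byte get(byte[] flat, int cols, int row, int col) {
		return flat[row * cols + col];
	}

	public static void main(String[] args) {
		byte[][] matrix = { { 1, 2, 3 }, { 4, 5, 6 } };
		// prints [1, 2, 3, 4, 5, 6]
		System.out.println(Arrays.toString(flatten(matrix)));
	}
}
```

Because the flattened form does not record N, a reader needs the column count from configuration, which is the sense in which the format is not self-describing.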


@@ -1,105 +0,0 @@
// Protocol messages for describing features for machine learning model
// training or inference.
//
// There are three base Feature types:
// - bytes
// - float
// - int64
//
// A Feature contains Lists which may hold zero or more values. These
// lists are the base values BytesList, FloatList, Int64List.
//
// Features are organized into categories by name. The Features message
// contains the mapping from name to Feature.
//
// Example Features for a movie recommendation application:
// feature {
// key: "age"
// value { float_list {
// value: 29.0
// }}
// }
// feature {
// key: "movie"
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }}
// }
// feature {
// key: "movie_ratings"
// value { float_list {
// value: 9.0
// value: 9.7
// }}
// }
// feature {
// key: "suggestion"
// value { bytes_list {
// value: "Inception"
// }}
// }
// feature {
// key: "suggestion_purchased"
// value { int64_list {
// value: 1
// }}
// }
// feature {
// key: "purchase_price"
// value { float_list {
// value: 9.99
// }}
// }
//
syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "FeatureProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.example";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
package tensorflow;
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
// Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};


@@ -1,25 +0,0 @@
// Message to store the mapping from class label strings to class id. Datasets
// use string labels to represent classes while the object detection framework
// works with class ids. This message maps them so they can be converted back
// and forth as needed.
syntax = "proto2";
package org.springframework.cloud.fn.object.detection.protos;
message StringIntLabelMapItem {
// String name. The most common practice is to set this to a MID or synsets
// id. Synset: a set of synonyms that share a common meaning.
// https://en.wikipedia.org/wiki/WordNet
optional string name = 1;
// Integer id that maps to the string name above. Label ids should start
// from 1.
optional int32 id = 2;
// Human readable string label.
optional string display_name = 3;
};
message StringIntLabelMap {
repeated StringIntLabelMapItem item = 1;
};
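
The label map above assigns integer ids that should start from 1, and the output converter earlier in this diff looks a label up by using the class id as an array index. How the removed module built that array is not shown here; the sketch below is one possible approach, with `LabelLookup` and `toArray` as hypothetical names.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class LabelLookup {

	// Builds an array indexed by class id. Slot 0 stays null because ids start from 1.
	// Assumption: ids are small positive integers, so a dense array is reasonable.
	static String[] toArray(Map<Integer, String> idToName) {
		int maxId = 0;
		for (int id : idToName.keySet()) {
			maxId = Math.max(maxId, id);
		}
		String[] labels = new String[maxId + 1];
		idToName.forEach((id, name) -> labels[id] = name);
		return labels;
	}

	public static void main(String[] args) {
		Map<Integer, String> items = new LinkedHashMap<>();
		items.put(1, "person");
		items.put(3, "car");
		String[] labels = toArray(items);
		System.out.println(labels[3]); // prints car
	}
}
```

Ids missing from the map leave a null slot, so a caller would need to handle a null label for classes the map does not cover.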

Binary file not shown (image, 1.4 MiB).

Binary file not shown (image, 48 KiB).


@@ -1,102 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.ResourceLoader;
/**
* Four of the pre-trained models in the model zoo
* (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)
* can also compute the masks of the detected objects, providing instance segmentation.
* <p>
* Here are the models that can be used for instance segmentation (model name, speed in
* ms, COCO mAP, outputs):
* <ul>
* <li>mask_rcnn_inception_resnet_v2_atrous_coco, 771, 36, Masks</li>
* <li>mask_rcnn_inception_v2_coco, 79, 25, Masks</li>
* <li>mask_rcnn_resnet101_atrous_coco, 470, 33, Masks</li>
* <li>mask_rcnn_resnet50_atrous_coco, 343, 29, Masks</li>
* </ul>
*
* @author Christian Tzolov
*/
public class ExampleInstanceSegmentation {
public static void main(String[] args) throws IOException {
ResourceLoader resourceLoader = new DefaultResourceLoader();
// You can download pre-trained models directly from the zoo:
// https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file>
// For performance reasons, consider downloading the model locally and using
// the file:/<path to my model> URI instead.
String model = "https://download.tensorflow.org/models/object_detection/mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
// All labels for the pre-trained models are available at:
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
// Use the labels applicable for the model.
// Also, for performance reasons, consider downloading the labels and loading
// them from file: instead.
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
// You can cache the TF model on the local file system to improve the bootstrap
// performance on consecutive runs!
boolean CACHE_TF_MODEL = true;
// For the pre-trained models with masks you can set
// INSTANCE_SEGMENTATION to enable object instance segmentation as well
boolean INSTANCE_SEGMENTATION = true;
// Only objects with confidence at or above the threshold are returned
float CONFIDENCE_THRESHOLD = 0.4f;
ObjectDetectionService detectionService = new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD,
INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
// You can use file:, http: or classpath: to provide the path to the input image.
byte[] image = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
// Returns a list of ObjectDetection domain objects to allow programmatic access to
// the detected objects' metadata
List<ObjectDetection> detectedObjects = detectionService.detect(image);
// Get JSON representation of the detected objects
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
System.out.println(jsonObjectDetections);
// Draw the detected object metadata on top of the original image and store the
// result
byte[] annotatedImage = new ObjectDetectionImageAugmenter(INSTANCE_SEGMENTATION).apply(image, detectedObjects);
File projectDir = new File("functions/function/object-detection-function");
File output = new File(projectDir, "target/object-detection-segmentation-augmented.jpg");
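// Close the stream once the image is written to avoid leaking the file handle
try (FileOutputStream outputStream = new FileOutputStream(output)) {
IOUtils.write(annotatedImage, outputStream);
}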
System.out.println("Created:" + output.getPath());
}
}


@@ -1,98 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.util.StreamUtils;
/**
* @author Christian Tzolov
*/
public class ExampleObjectDetection {
public static void main(String[] args) throws IOException {
// You can download pre-trained models directly from the zoo:
// https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file>
// For performance reasons, consider downloading the model locally and using
// the file:/<path to my model> URI instead.
String model = "https://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
// Resource model =
// resourceLoader.getResource("https://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
// Resource model =
// resourceLoader.getResource("https://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
// All labels for the pre-trained models are available at:
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
// Use the labels applicable for the model.
// Also, for performance reasons you may consider downloading the labels and loading
// them from file: instead.
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
// Resource labels =
// resourceLoader.getResource("https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/fgvc_2854_classes_label_map.pbtxt");
// You can cache the TF model on the local file system to improve the bootstrap
// performance on consecutive runs!
boolean CACHE_TF_MODEL = true;
// For the pre-trained models with mask you can set the
// INSTANCE_SEGMENTATION to enable object instance segmentation as well
boolean NO_INSTANCE_SEGMENTATION = false;
// Only objects with confidence above the threshold are returned
float CONFIDENCE_THRESHOLD = 0.4f;
ObjectDetectionService detectionService = new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD,
NO_INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
// You can use file:, http: or classpath: to provide the path to the input image.
String inputImageUri = "classpath:/images/object-detection.jpg";
try (InputStream is = new DefaultResourceLoader().getResource(inputImageUri).getInputStream()) {
byte[] image = StreamUtils.copyToByteArray(is);
// Returns a list of ObjectDetection domain objects to allow programmatic
// access to the detected objects' metadata
List<ObjectDetection> detectedObjects = detectionService.detect(image);
// Get JSON representation of the detected objects
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
System.out.println(jsonObjectDetections);
// Draw the detected object metadata on top of the original image and store
// the result
byte[] annotatedImage = new ObjectDetectionImageAugmenter(NO_INSTANCE_SEGMENTATION).apply(image,
detectedObjects);
IOUtils.write(annotatedImage,
new FileOutputStream("./object-detection-function/target/object-detection-augmented.jpg"));
}
}
}

View File

@@ -1,62 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.util.List;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
/**
* @author Christian Tzolov
*/
public class SimpleExample {
public static void main(String[] args) {
// Select a pre-trained model from the model zoo:
// https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <model zoo url>#<name of the frozen model file in the
// zoo's tar.gz>
String model = "https://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz#frozen_inference_graph.pb";
// All labels for the pre-trained models are available at:
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
ObjectDetectionService detectionService = new ObjectDetectionService(model, labels,
0.4f, // Only objects with confidence above the threshold are returned. Confidence range is [0, 1].
false, // No instance segmentation
true); // cache the TF model locally
// You can use file:, http: or classpath: to provide the path to the input image.
List<ObjectDetection> detectedObjects = detectionService.detect("classpath:/images/object-detection.jpg");
detectedObjects.stream().map(o -> o.toString()).forEach(System.out::println);
}
}

View File

@@ -1,104 +0,0 @@
:images-asciidoc: https://raw.githubusercontent.com/spring-cloud/stream-applications/master/functions/function/semantic-segmentation-function/src/main/resources/images/
# Semantic Segmentation
[.lead]
Image Semantic Segmentation based on the state-of-the-art https://github.com/tensorflow/models/tree/master/research/deeplab[DeepLab] TensorFlow model.
[cols="1,2", frame=none, grid=none]
|===
| image:{images-asciidoc}/VikiMaxiAdi-all.png[width=100%]
|Semantic Segmentation is the process of associating each pixel of an image with a class label (such as flower, person, road, sky, ocean, or car).
Unlike `Instance Segmentation`, which produces instance-aware region masks, `Semantic Segmentation` produces class-aware masks.
To implement `Instance Segmentation`, consult the https://github.com/spring-cloud/stream-applications/tree/master/functions/function/object-detection-function[Object Detection Service] instead.
|===
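To make "class-aware" concrete: a semantic segmentation mask is a grid of class label ids, one per pixel, with no notion of individual object instances. The sketch below is illustrative only. The class `ClassMaskSketch`, the method `pixelsPerClass`, and the label ids `0` and `15` are hypothetical and do not belong to this module; the `long[][]` input merely mirrors the shape of a per-pixel label mask.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of this module: summarizes a class-aware mask.
public final class ClassMaskSketch {

    private ClassMaskSketch() {
    }

    /** Counts how many pixels carry each class label id. */
    public static Map<Long, Integer> pixelsPerClass(long[][] mask) {
        Map<Long, Integer> counts = new TreeMap<>();
        for (long[] row : mask) {
            for (long label : row) {
                counts.merge(label, 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // Made-up 2x3 mask: label 0 and label 15 are arbitrary example ids.
        long[][] mask = { { 0, 0, 15 }, { 0, 15, 15 } };
        System.out.println(pixelsPerClass(mask)); // prints {0=3, 15=3}
    }
}
```

Two separate objects of the same class would both contribute to the same label count here, which is exactly what distinguishes a class-aware mask from instance-aware masks.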
The https://github.com/spring-cloud/stream-applications/blob/master/functions/common/tensorflow-common/src/main/java/org/springframework/cloud/fn/common/tensorflow/deprecated/JsonMapperFunction.java[JsonMapperFunction]
converts the `List<ObjectDetection>` into JSON objects, and the
https://github.com/spring-cloud/stream-applications/blob/master/functions/function/object-detection-function/src/main/java/org/springframework/cloud/fn/object/detection/ObjectDetectionImageAugmenter.java[ObjectDetectionImageAugmenter]
augments the input image with the detected bounding boxes and segmentation masks.
## Usage
Add the `semantic-segmentation-function` dependency to your POM (_use the latest version available_):
[source,xml]
----
<dependency>
<groupId>org.springframework.cloud.fn</groupId>
<artifactId>semantic-segmentation-function</artifactId>
<version>${revision}</version>
</dependency>
<dependency>
<groupId>org.springframework.cloud.fn</groupId>
<artifactId>object-detection-function</artifactId>
<version>${revision}</version>
</dependency>
----
The following snippet demonstrates how to use the PASCAL VOC model to apply a mask to an input image:
[source,java,linenums]
----
SemanticSegmentation segmentationService = new SemanticSegmentation(
"https://download.tensorflow.org/models/deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz#frozen_inference_graph.pb", // <1>
true); // <2>
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/VikiMaxiAdi.jpg"); // <3>
byte[] imageMask = segmentationService.masksAsImage(inputImage); // <4>
BufferedImage bi = ImageIO.read(new ByteArrayInputStream(imageMask));
ImageIO.write(bi, "png", new FileOutputStream("./semantic-segmentation-function/target/VikiMaxiAdi_masks.png"));
byte[] augmentedImage = segmentationService.augment(inputImage); // <5>
IOUtils.write(augmentedImage, new FileOutputStream("./semantic-segmentation-function/target/VikiMaxiAdi_augmented.jpg"));
----
<1> Download the PASCAL 2012 trained model directly from the web. The `frozen_inference_graph.pb` is the name of the model
file inside the `tar.gz` archive.
<2> Cache the downloaded model locally
<3> Load the input image as byte array
<4> Get the segmentation mask as a separate image
<5> Blend the segmentation mask on top of the original image
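The blending step can be pictured per pixel. The following is a minimal plain-Java sketch of standard alpha compositing, `out = mask * a + image * (1 - a)`, and is not the TensorFlow-based code this module uses. The class `AlphaBlendSketch`, the method `blend`, and the parameter `maskOpacity` are hypothetical names introduced only for illustration.

```java
import java.util.Arrays;

// Hypothetical helper, not part of this module: blends a single RGB pixel.
public final class AlphaBlendSketch {

    private AlphaBlendSketch() {
    }

    /**
     * Blends one RGB mask pixel over one RGB image pixel.
     * maskOpacity is assumed to be in [0, 1]: 0 keeps the image, 1 keeps the mask.
     */
    public static int[] blend(int[] imageRgb, int[] maskRgb, float maskOpacity) {
        int[] out = new int[3];
        for (int c = 0; c < 3; c++) {
            float value = maskRgb[c] * maskOpacity + imageRgb[c] * (1.0f - maskOpacity);
            // Clamp to the valid 8-bit channel range.
            out[c] = Math.min(255, Math.max(0, Math.round(value)));
        }
        return out;
    }

    public static void main(String[] args) {
        int[] blended = blend(new int[] { 200, 100, 0 }, new int[] { 0, 100, 200 }, 0.5f);
        System.out.println(Arrays.toString(blended)); // prints [100, 100, 100]
    }
}
```

Applying the same formula to every pixel of the resized input and the colored mask yields an augmented image in which the original content remains visible underneath the mask colors.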
## Models
Based on the training datasets, three groups of pre-trained models are provided:
[cols="1,2", frame=none, grid=none]
|===
| image:{images-asciidoc}/VikiMaxiAdi-all.png[width=100%]
| https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md#deeplab-models-trained-on-pascal-voc-2012[DeepLab models trained on PASCAL VOC 2012]
| image:{images-asciidoc}/cityscape-all-small.png[width=100%]
| https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md#deeplab-models-trained-on-cityscapes[DeepLab models trained on Cityscapes]
| image:{images-asciidoc}/ADE20K-all-small.png[width=100%]
| https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md#deeplab-models-trained-on-ade20k[DeepLab models trained on ADE20K]
|===
Select the model you want to use, copy its archive download URL and append a `#frozen_inference_graph.pb` fragment to it.
The fragment is the frozen model's file name inside the archive.
TIP: Download the archive and extract the `frozen_inference_graph.pb` for the required model. Then use the `file://<local-file-name>` URI scheme.
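The `<archive url>#<file name>` notation can be illustrated with the JDK's `java.net.URI`. This is a sketch under assumptions: the class `ModelUriSketch` and its methods are hypothetical, and how the module itself resolves the fragment is not shown here.

```java
import java.net.URI;

// Hypothetical helper, not part of this module: composes and splits model URIs.
public final class ModelUriSketch {

    private static final String FROZEN_GRAPH = "frozen_inference_graph.pb";

    private ModelUriSketch() {
    }

    /** Appends the frozen graph file name as a URI fragment. */
    public static String withFrozenGraph(String archiveUrl) {
        return archiveUrl + "#" + FROZEN_GRAPH;
    }

    /** Returns the archive location, i.e. everything before the '#'. */
    public static String archiveLocation(String modelUri) {
        int hash = modelUri.indexOf('#');
        return (hash < 0) ? modelUri : modelUri.substring(0, hash);
    }

    /** Returns the file name inside the archive, or null if there is no fragment. */
    public static String entryName(String modelUri) {
        return URI.create(modelUri).getFragment();
    }

    public static void main(String[] args) {
        String model = withFrozenGraph(
                "https://download.tensorflow.org/models/deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz");
        System.out.println(archiveLocation(model));
        System.out.println(entryName(model)); // prints frozen_inference_graph.pb
    }
}
```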
For convenience, a couple of models have also been extracted from their archives and uploaded to Bintray:
[cols=2*, frame=none, grid=none]
|===
|PASCAL VOC 2012 (default)
|https://dl.bintray.com/big-data/generic/deeplabv3_mnv2_pascal_train_aug_frozen_inference_graph.pb
|CITYSCAPE
|https://dl.bintray.com/big-data/generic/deeplabv3_mnv2_cityscapes_train_2018_02_05_frozen_inference_graph.pb
|ADE20K
|https://dl.bintray.com/big-data/generic/deeplabv3_xception_ade20k_train_2018_05_29_frozen_inference_graph.pb
|===
## References:
[.small]
* https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html[Semantic Image Segmentation with DeepLab in TensorFlow]
* https://github.com/tensorflow/models/tree/master/research/deeplab[DeepLab Project]
* https://medium.freecodecamp.org/how-to-use-deeplab-in-tensorflow-for-object-segmentation-using-deep-learning-a5777290ab6b[How to re-train DeepLab Segmentation models using Transfer Learning]

View File

@@ -1,3 +0,0 @@
dependencies {
api project(':spring-tensorflow-common')
}

View File

@@ -1,114 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.semantic.segmentation;
import java.util.Arrays;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.ExpandDims;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.ReduceMax;
import org.tensorflow.op.core.Tile;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mul;
import org.tensorflow.op.math.Sub;
/**
* @author Christian Tzolov
*/
public final class NativeImageUtils {
private NativeImageUtils() {
}
/**
* grayscaleToRgb.
* https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/image_ops_impl.py#L1536
*/
public static <T> Operand<T> grayscaleToRgb(Ops tf, Operand<T> images) {
ExpandDims<Integer> rank_1 = tf.expandDims(tf.math.sub(tf.rank(images), tf.constant(1)), tf.constant(0));
// Create a 1D vector of ones whose length is given by rank_1.
// E.g. for rank_1 = [2] it produces [1, 1]; for rank_1 = [3] it produces
// [1, 1, 1]
Add<Integer> ones = tf.math.add(tf.zeros(rank_1, Integer.class), tf.constant(1));
// Convert scalar 3 into 1D array [3]
ExpandDims<Integer> channelsAs1D = tf.expandDims(tf.constant(3), tf.constant(0));
Concat<Integer> shapeList = tf.concat(Arrays.asList(ones, channelsAs1D), tf.constant(0));
Tile<T> tile = tf.withName("grayscaleToRgb").tile(images, shapeList);
return tile;
}
public static Operand<Float> normalizeMask(Ops tf, Operand<Float> mask, float newValue) {
// Generate an array representing the axis indexes.
// For example, for a tensor of rank K the axisRange is {0, 1, ..., K - 1}
Range<Integer> axisRange = tf.range(tf.constant(0), // from
tf.dtypes.cast(tf.rank(mask), Integer.class), // to
tf.constant(1)); // step
ReduceMax<Float> max = tf.reduceMax(mask, axisRange);
// Mul<Float> input2Float1 = tf.math.mul(tf.math.div(input2Float, max),
// tf.constant(1f));
Mul<Float> normalizedMask = tf.math.mul(tf.math.div(mask, max), tf.constant(newValue));
return normalizedMask;
}
/**
* Alpha Blending . https://en.wikipedia.org/wiki/Alpha_compositing#Alpha_blending
*/
public static Operand<Float> alphaBlending(Ops tf, Operand<Float> srcRgb, Operand<Float> dstRgb,
Operand<Float> srcAlpha) {
Sub<Float> alpha = tf.math.sub(tf.onesLike(srcRgb), srcAlpha);
Mul<Float> src = tf.math.mul(srcRgb, alpha);
Mul<Float> dst = tf.math.mul(dstRgb, tf.math.sub(tf.constant(1.0f), alpha));
Add<Float> out = tf.math.add(dst, src);
// Mul<Float> out = tf.math.mul(srcRgbNormalized, dstRgb);
// Squeeze<Float> squeeze = tf.withName("squeeze").squeeze(out,
// Squeeze.axis(Arrays.asList(0L)));
return out;
}
/**
* The mask can contain label values larger than the list of colors provided in the
* color map. To avoid out-of-index errors we will "normalize" the label values in the
* mask to MOD max-color-table-value.
* @param tf - tensorflow
* @param colorTable Color map of shape [n, 3]. n is the count of label entries and 3
* is the RGB color assigned to that label.
* @param mask Mask of shape [h, w] containing label values.
* @return Mask of shape [h, w] with label values reduced modulo (n - 1), so that they
* stay within the color table.
*/
public static Operand<Long> normalizeMaskLabels(Ops tf, Operand<Integer> colorTable, Operand<Long> mask) {
// The mask can contain label values larger than the list of colors provided in
// the color map.
// To avoid out-of-index errors we will "normalize" the label values in the mask
// to MOD max-color-table-value.
Sub<Long> colorTableShape = tf.math.sub(tf.shape(colorTable, Long.class), tf.constant(1L));
// Color tables have shape [N, 3], where N is the count of label entries.
// Therefore the max label id is (N - 1).
Gather<Long> colorTableSize = tf.gather(colorTableShape, tf.constant(new int[] { 0 }), tf.constant(0));
// Normalize the label values in the mask so they don't exceed the max value in
// the color map.
return tf.math.mod(mask, colorTableSize);
}
}

View File

@@ -1,162 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.semantic.segmentation;
import java.io.InputStream;
import java.util.Arrays;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.core.io.DefaultResourceLoader;
/**
*
* Visualizes the segmentation results via a specified color map. Color maps help to
* visualize the semantic segmentation results for the different datasets.
*
* Supported colormaps are: - ADE20K
* (http://groups.csail.mit.edu/vision/datasets/ADE20K/). - Cityscapes dataset
* (https://www.cityscapes-dataset.com). - Mapillary Vistas
* (https://research.mapillary.com). - PASCAL VOC 2012
* (http://host.robots.ox.ac.uk/pascal/VOC/).
*
* Based on:
* https://github.com/tensorflow/models/blob/master/research/deeplab/utils/get_dataset_colormap.py
*
* @author Christian Tzolov
*/
public final class SegmentationColorMap {
private SegmentationColorMap() {
}
/** MAPILLARY_COLORMAP . */
public static final int[][] MAPILLARY_COLORMAP = new int[][] { { 165, 42, 42 }, { 0, 192, 0 }, { 196, 196, 196 },
{ 190, 153, 153 }, { 180, 165, 180 }, { 102, 102, 156 }, { 102, 102, 156 }, { 128, 64, 255 },
{ 140, 140, 200 }, { 170, 170, 170 }, { 250, 170, 160 }, { 96, 96, 96 }, { 230, 150, 140 },
{ 128, 64, 128 }, { 110, 110, 110 }, { 244, 35, 232 }, { 150, 100, 100 }, { 70, 70, 70 }, { 150, 120, 90 },
{ 220, 20, 60 }, { 255, 0, 0 }, { 255, 0, 0 }, { 255, 0, 0 }, { 200, 128, 128 }, { 255, 255, 255 },
{ 64, 170, 64 }, { 128, 64, 64 }, { 70, 130, 180 }, { 255, 255, 255 }, { 152, 251, 152 }, { 107, 142, 35 },
{ 0, 170, 30 }, { 255, 255, 128 }, { 250, 0, 30 }, { 0, 0, 0 }, { 220, 220, 220 }, { 170, 170, 170 },
{ 222, 40, 40 }, { 100, 170, 30 }, { 40, 40, 40 }, { 33, 33, 33 }, { 170, 170, 170 }, { 0, 0, 142 },
{ 170, 170, 170 }, { 210, 170, 100 }, { 153, 153, 153 }, { 128, 128, 128 }, { 0, 0, 142 }, { 250, 170, 30 },
{ 192, 192, 192 }, { 220, 220, 0 }, { 180, 165, 180 }, { 119, 11, 32 }, { 0, 0, 142 }, { 0, 60, 100 },
{ 0, 0, 142 }, { 0, 0, 90 }, { 0, 0, 230 }, { 0, 80, 100 }, { 128, 64, 64 }, { 0, 0, 110 }, { 0, 0, 70 },
{ 0, 0, 192 }, { 32, 32, 32 }, { 0, 0, 0 }, { 0, 0, 0 }, };
/**
* Label colormap used in ADE20K segmentation benchmark.
*/
public static final int[][] ADE20K_COLORMAP = new int[][] { { 0, 0, 0 }, { 120, 120, 120 }, { 180, 120, 120 },
{ 6, 230, 230 }, { 80, 50, 50 }, { 4, 200, 3 }, { 120, 120, 80 }, { 140, 140, 140 }, { 204, 5, 255 },
{ 230, 230, 230 }, { 4, 250, 7 }, { 224, 5, 255 }, { 235, 255, 7 }, { 150, 5, 61 }, { 120, 120, 70 },
{ 8, 255, 51 }, { 255, 6, 82 }, { 143, 255, 140 }, { 204, 255, 4 }, { 255, 51, 7 }, { 204, 70, 3 },
{ 0, 102, 200 }, { 61, 230, 250 }, { 255, 6, 51 }, { 11, 102, 255 }, { 255, 7, 71 }, { 255, 9, 224 },
{ 9, 7, 230 }, { 220, 220, 220 }, { 255, 9, 92 }, { 112, 9, 255 }, { 8, 255, 214 }, { 7, 255, 224 },
{ 255, 184, 6 }, { 10, 255, 71 }, { 255, 41, 10 }, { 7, 255, 255 }, { 224, 255, 8 }, { 102, 8, 255 },
{ 255, 61, 6 }, { 255, 194, 7 }, { 255, 122, 8 }, { 0, 255, 20 }, { 255, 8, 41 }, { 255, 5, 153 },
{ 6, 51, 255 }, { 235, 12, 255 }, { 160, 150, 20 }, { 0, 163, 255 }, { 140, 140, 140 }, { 250, 10, 15 },
{ 20, 255, 0 }, { 31, 255, 0 }, { 255, 31, 0 }, { 255, 224, 0 }, { 153, 255, 0 }, { 0, 0, 255 },
{ 255, 71, 0 }, { 0, 235, 255 }, { 0, 173, 255 }, { 31, 0, 255 }, { 11, 200, 200 }, { 255, 82, 0 },
{ 0, 255, 245 }, { 0, 61, 255 }, { 0, 255, 112 }, { 0, 255, 133 }, { 255, 0, 0 }, { 255, 163, 0 },
{ 255, 102, 0 }, { 194, 255, 0 }, { 0, 143, 255 }, { 51, 255, 0 }, { 0, 82, 255 }, { 0, 255, 41 },
{ 0, 255, 173 }, { 10, 0, 255 }, { 173, 255, 0 }, { 0, 255, 153 }, { 255, 92, 0 }, { 255, 0, 255 },
{ 255, 0, 245 }, { 255, 0, 102 }, { 255, 173, 0 }, { 255, 0, 20 }, { 255, 184, 184 }, { 0, 31, 255 },
{ 0, 255, 61 }, { 0, 71, 255 }, { 255, 0, 204 }, { 0, 255, 194 }, { 0, 255, 82 }, { 0, 10, 255 },
{ 0, 112, 255 }, { 51, 0, 255 }, { 0, 194, 255 }, { 0, 122, 255 }, { 0, 255, 163 }, { 255, 153, 0 },
{ 0, 255, 10 }, { 255, 112, 0 }, { 143, 255, 0 }, { 82, 0, 255 }, { 163, 255, 0 }, { 255, 235, 0 },
{ 8, 184, 170 }, { 133, 0, 255 }, { 0, 255, 92 }, { 184, 0, 255 }, { 255, 0, 31 }, { 0, 184, 255 },
{ 0, 214, 255 }, { 255, 0, 112 }, { 92, 255, 0 }, { 0, 224, 255 }, { 112, 224, 255 }, { 70, 184, 160 },
{ 163, 0, 255 }, { 153, 0, 255 }, { 71, 255, 0 }, { 255, 0, 163 }, { 255, 204, 0 }, { 255, 0, 143 },
{ 0, 255, 235 }, { 133, 255, 0 }, { 255, 0, 235 }, { 245, 0, 255 }, { 255, 0, 122 }, { 255, 245, 0 },
{ 10, 190, 212 }, { 214, 255, 0 }, { 0, 204, 255 }, { 20, 0, 255 }, { 255, 255, 0 }, { 0, 153, 255 },
{ 0, 41, 255 }, { 0, 255, 204 }, { 41, 0, 255 }, { 41, 255, 0 }, { 173, 0, 255 }, { 0, 245, 255 },
{ 71, 0, 255 }, { 122, 0, 255 }, { 0, 255, 184 }, { 0, 92, 255 }, { 184, 255, 0 }, { 0, 133, 255 },
{ 255, 214, 0 }, { 25, 194, 194 }, { 102, 255, 0 }, { 92, 0, 255 }, };
/** BLACK_WHITE_COLORMAP . */
public static final int[][] BLACK_WHITE_COLORMAP = new int[][] { { 0, 0, 0 }, { 127, 127, 127 }, { 255, 255, 255 }, };
/** CITYMAP_COLORMAP . */
public static final int[][] CITYMAP_COLORMAP = new int[255][3];
static {
// Initialize citymap
int[][] _CITYMAP_COLORMAP = new int[][] { { 128, 64, 128 }, { 244, 35, 232 }, { 70, 70, 70 }, { 102, 102, 156 },
{ 190, 153, 153 }, { 153, 153, 153 }, { 250, 170, 30 }, { 220, 220, 0 }, { 107, 142, 35 },
{ 152, 251, 152 }, { 70, 130, 180 }, { 220, 20, 60 }, { 255, 0, 0 }, { 0, 0, 142 }, { 0, 0, 70 },
{ 0, 60, 100 }, { 0, 80, 100 }, { 0, 0, 230 }, { 119, 11, 32 } };
for (int i = 0; i < _CITYMAP_COLORMAP.length; i++) {
System.arraycopy(_CITYMAP_COLORMAP[i], 0, CITYMAP_COLORMAP[i], 0, _CITYMAP_COLORMAP[i].length);
}
}
public static int[][] loadColorMap(String resourceUri) {
try {
InputStream colorMapIs = new DefaultResourceLoader().getResource(resourceUri).getInputStream();
ColorMap colorMap = new ObjectMapper().readValue(colorMapIs, ColorMap.class);
return colorMap.getColormap();
}
catch (Exception exception) {
throw new RuntimeException(exception);
}
}
public static class ColorMap {
private String name;
private String info;
private int[][] colormap;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public String getInfo() {
return info;
}
public void setInfo(String info) {
this.info = info;
}
public int[][] getColormap() {
return colormap;
}
public void setColormap(int[][] colormap) {
this.colormap = colormap;
}
@Override
public String toString() {
return "ColorMap{" + "name='" + name + '\'' + ", info='" + info + '\'' + ", colormap="
+ Arrays.deepToString(colormap) + '}';
}
}
}

View File

@@ -1,348 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.semantic.segmentation;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import javax.imageio.ImageIO;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Gather;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Squeeze;
import org.tensorflow.op.core.ZerosLike;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.op.image.ExtractJpegShape;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Div;
import org.tensorflow.op.math.Equal;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.Functions;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.core.io.DefaultResourceLoader;
/**
* @author Christian Tzolov
*/
public class SemanticSegmentation implements AutoCloseable {
private static final long CHANNELS = 3;
private static final float REQUIRED_INPUT_IMAGE_SIZE = 513f;
private final GraphRunner imageNormalization;
private final GraphRunner semanticSegmentation;
private final GraphRunner maskImageEncoding;
private final GraphRunner alphaBlending;
private final Tensor<Integer> colorMapTensor;
private final Tensor<Float> maskTransparencyTensor;
@Override
public void close() {
this.imageNormalization.close();
this.semanticSegmentation.close();
this.maskImageEncoding.close();
this.alphaBlending.close();
this.colorMapTensor.close();
this.maskTransparencyTensor.close();
}
public SemanticSegmentation(String modelUrl, int[][] colorMap, long[] labelFilter, float maskTransparency) {
this.imageNormalization = new GraphRunner("input_image", "resized_image").withGraphDefinition(tf -> {
Placeholder<String> input = tf.withName("input_image").placeholder(String.class);
ExtractJpegShape<Integer> imageShapeAndChannel = tf.image.extractJpegShape(input);
Gather<Integer> imageShape = tf.gather(imageShapeAndChannel, tf.constant(new int[] { 0, 1 }),
tf.constant(0));
Cast<Float> maxSize = tf.dtypes.cast(tf.max(imageShape, tf.constant(0)), Float.class);
Div<Float> scale = tf.math.div(tf.constant(REQUIRED_INPUT_IMAGE_SIZE), maxSize);
Cast<Integer> newSize = tf.dtypes.cast(tf.math.mul(scale, tf.dtypes.cast(imageShape, Float.class)),
Integer.class);
final Operand<Float> decodedImage = tf.dtypes
.cast(tf.image.decodeJpeg(input, DecodeJpeg.channels(CHANNELS)), Float.class);
final Operand<Float> resizedImageFloat = tf.image
.resizeBilinear(tf.expandDims(decodedImage, tf.constant(0)), newSize);
tf.withName("resized_image").dtypes.cast(resizedImageFloat, UInt8.class);
});
this.semanticSegmentation = new GraphRunner("ImageTensor:0", "SemanticPredictions:0")
.withGraphDefinition(new ProtoBufGraphDefinition(new DefaultResourceLoader().getResource(modelUrl), true));
this.colorMapTensor = Tensor.create(colorMap).expect(Integer.class);
this.maskImageEncoding = new GraphRunner(Arrays.asList("color_map", "mask_pixels"),
Arrays.asList("mask_png", "mask_rgb"))
.withGraphDefinition(tf -> {
Placeholder<Integer> colorTable = tf.withName("color_map").placeholder(Integer.class);
Placeholder<Long> batchedMask = tf.withName("mask_pixels").placeholder(Long.class);
// Remove batch dimension
Squeeze<Long> mask = tf.squeeze(batchedMask, Squeeze.axis(Arrays.asList(0L)));
Operand<Long> filteredMask = labelFilter(tf, mask, labelFilter);
// The mask can contain label values larger than the list of colors
// provided in the color map.
// To avoid out-of-index errors we will "normalize" the label values in
// the mask to MOD max-color-table-value.
Operand<Long> mask3 = NativeImageUtils.normalizeMaskLabels(tf, colorTable, filteredMask);
Gather<Integer> maskRgb = tf.withName("mask_rgb").gather(colorTable, mask3, tf.constant(0));
Operand<String> png = tf.withName("mask_png").image.encodePng(tf.dtypes.cast(maskRgb, UInt8.class));
});
this.maskTransparencyTensor = Tensor.create(maskTransparency).expect(Float.class);
this.alphaBlending = new GraphRunner(Arrays.asList("input_image", "mask_image", "mask_transparency"),
Arrays.asList("blended_png"))
.withGraphDefinition(tf -> {
// Input image [B, H, W, 3]
Cast<Float> inputImageRgb = tf.dtypes.cast(tf.withName("input_image").placeholder(UInt8.class),
Float.class);
Placeholder<Integer> a = tf.withName("mask_image").placeholder(Integer.class);
Cast<Float> maskRgb = tf.dtypes.cast(a, Float.class);
Squeeze<Float> inputImageRgb2 = tf.squeeze(inputImageRgb, Squeeze.axis(Arrays.asList(0L)));
Placeholder<Float> maskTransparencyHolder = tf.withName("mask_transparency").placeholder(Float.class);
// Blend the transparent maskImage on top of the input image.
Operand<Float> blended = NativeImageUtils.alphaBlending(tf, maskRgb, inputImageRgb2,
maskTransparencyHolder);
// Cut
// Operand<Boolean> condition = tf.math.equal(a, tf.zerosLike(a));
// Operand<Float> blended = tf.where3(condition, tf.zerosLike(maskRgb),
// inputImageRgb2);
// Encode PNG
tf.withName("blended_png").image.encodePng(tf.dtypes.cast(blended, UInt8.class));
});
}
public byte[] blendMask(byte[] image) {
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memory = new GraphRunnerMemory()) {
Map<String, Tensor<?>> blendedTensors = this.imageNormalization.andThen(memory) // (input_image) -> (resized_image) and memorize (resized_image)
.andThen(this.semanticSegmentation)
.andThen(memory) // (ImageTensor:0) -> (SemanticPredictions:0) and memorize (SemanticPredictions:0)
.andThen(Functions.rename("SemanticPredictions:0", "mask_pixels")) // (SemanticPredictions:0) -> (mask_pixels)
.andThen(Functions.enrichWith("color_map", this.colorMapTensor)) // (mask_pixels) -> (mask_pixels, color_map)
.andThen(this.maskImageEncoding)
.andThen(memory) // (color_map, mask_pixels) -> (mask_png, mask_rgb) and memorize (mask_png, mask_rgb)
.andThen(Functions.enrichFromMemory(memory, "resized_image")) // (mask_png, mask_rgb) -> (mask_png, mask_rgb, resized_image), e.g. join the normalizedImageTensor
.andThen(Functions.rename("resized_image", "input_image", "mask_rgb", "mask_image")) // (mask_png, mask_rgb, resized_image) -> (mask_image, input_image)
.andThen(Functions.enrichWith("mask_transparency", this.maskTransparencyTensor)) // (mask_image, input_image) -> (mask_image, input_image, mask_transparency)
.andThen(this.alphaBlending)
.andThen(memory) // (mask_image, input_image, mask_transparency) -> (blended_png)
.apply(Collections.singletonMap("input_image", inputTensor)); // () -> (input_image)
byte[] blendedImage = blendedTensors.get("blended_png").bytesValue();
memory.getTensorMap().entrySet().stream().forEach(e -> System.out.println(e));
return blendedImage;
}
}
public long[][] maskPixels(byte[] image) {
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memory = new GraphRunnerMemory()) {
return this.imageNormalization.andThen(memory) // (input_image) -> (resized_image) and memorize (resized_image)
.andThen(this.semanticSegmentation)
.andThen(memory) // (ImageTensor:0) -> (SemanticPredictions:0) and memorize (SemanticPredictions:0)
.andThen(tensorMap -> {
Tensor<?> maskTensor = tensorMap.get("SemanticPredictions:0");
int width = (int) maskTensor.shape()[1];
int height = (int) maskTensor.shape()[2];
return maskTensor.copyTo(new long[1][width][height])[0]; // 1 == batch size
})
.apply(Collections.singletonMap("input_image", inputTensor)); // () -> (input_image)
}
}
public byte[] maskImage(byte[] image) {
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memory = new GraphRunnerMemory()) {
return this.imageNormalization.andThen(memory) // (input_image) -> (resized_image) and memorize (resized_image)
.andThen(this.semanticSegmentation)
.andThen(memory) // (ImageTensor:0) -> (SemanticPredictions:0) and memorize (SemanticPredictions:0)
.andThen(Functions.rename("SemanticPredictions:0", "mask_pixels")) // (SemanticPredictions:0) -> (mask_pixels)
.andThen(Functions.enrichWith("color_map", this.colorMapTensor)) // (mask_pixels) -> (mask_pixels, color_map)
.andThen(this.maskImageEncoding)
.andThen(memory) // (color_map, mask_pixels) -> (mask_png, mask_rgb) and memorize (mask_png, mask_rgb)
.andThen(tensorMap -> tensorMap.get("mask_png").bytesValue())
.apply(Collections.singletonMap("input_image", inputTensor)); // () -> (input_image)
}
}
private Operand<Long> labelFilter(Ops tf, Operand<Long> mask, long[] labels) {
if (labels == null || labels.length == 0) {
return mask;
}
ZerosLike<Long> zeroMask = tf.zerosLike(mask);
Operand<Long> result = zeroMask;
for (long label : labels) {
Add<Long> labelMask = tf.math.add(tf.zerosLike(mask), tf.constant(label));
Equal condition = tf.math.equal(mask, labelMask);
result = tf.math.add(result, tf.where3(condition, labelMask, zeroMask));
}
return result;
}
public static void main(String[] args) throws IOException {
try (SemanticSegmentation segmentationService = new SemanticSegmentation(
"https://download.tensorflow.org/models/deeplabv3_mnv2_cityscapes_train_2018_02_05.tar.gz#frozen_inference_graph.pb",
SegmentationColorMap.loadColorMap("classpath:/colormap/citymap_colormap.json"), null, 0.45f)) {
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/amsterdam-cityscape1.jpg");
// 1. Mask pixels
long[][] maskPixels = segmentationService.maskPixels(inputImage);
String json = new ObjectMapper().writeValueAsString(maskPixels);
// 2. Alpha Blending
byte[] blended = segmentationService.blendMask(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(blended)), "png",
new File("./functions/function/semantic-segmentation-function/target/blendedImage.png"));
// 3. Mask Image
byte[] maskImage = segmentationService.maskImage(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(maskImage)), "png",
new File("./functions/function/semantic-segmentation-function/target/maskImage.png"));
}
try (SemanticSegmentation segmentationService = new SemanticSegmentation(
"https://download.tensorflow.org/models/deeplabv3_xception_ade20k_train_2018_05_29.tar.gz#frozen_inference_graph.pb",
SegmentationColorMap.loadColorMap("classpath:/colormap/ade20k_colormap.json"), null, 0.45f)) {
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/interior.jpg");
// 1. Mask pixels
long[][] maskPixels = segmentationService.maskPixels(inputImage);
// 2. Alpha Blending
byte[] blended = segmentationService.blendMask(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(blended)), "png",
new File("./functions/function/semantic-segmentation-function/target/inventory-blendedImage.png"));
// 3. Mask Image
byte[] maskImage = segmentationService.maskImage(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(maskImage)), "png",
new File("./functions/function/semantic-segmentation-function/target/inventory-MaskImage.png"));
}
try (SemanticSegmentation segmentationService = new SemanticSegmentation(
"https://download.tensorflow.org/models/deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz#frozen_inference_graph.pb",
SegmentationColorMap.loadColorMap("classpath:/colormap/black_white_colormap.json"), null, 0.45f)) {
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/VikiMaxiAdi.jpg");
// 1. Mask pixels
long[][] maskPixels = segmentationService.maskPixels(inputImage);
// 2. Alpha Blending
byte[] blended = segmentationService.blendMask(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(blended)), "png",
new File("./functions/function/semantic-segmentation-function/target/pascal-blendedImage.png"));
// 3. Mask Image
byte[] maskImage = segmentationService.maskImage(inputImage);
ImageIO.write(ImageIO.read(new ByteArrayInputStream(maskImage)), "png",
new File("./functions/function/semantic-segmentation-function/target/pascal-MaskImage.png"));
}
}
}
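
The `labelFilter` method above builds a TensorFlow graph that keeps only the mask values listed in `labels` and zeroes everything else. The same idea on a plain Java array might look like the sketch below. `LabelFilterSketch` and `filterLabels` are hypothetical names that were not part of the removed module, and the sketch assumes the labels are distinct (the graph version would add a duplicated label twice).

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of the removed module.
final class LabelFilterSketch {

    // Returns a copy of mask in which every value not listed in labels is set to 0.
    static long[][] filterLabels(long[][] mask, long[] labels) {
        if (labels == null || labels.length == 0) {
            // Same early exit as labelFilter: no labels means no filtering.
            return mask;
        }
        long[] sorted = labels.clone();
        Arrays.sort(sorted);
        long[][] result = new long[mask.length][];
        for (int row = 0; row < mask.length; row++) {
            result[row] = new long[mask[row].length];
            for (int col = 0; col < mask[row].length; col++) {
                long label = mask[row][col];
                // Keep the label only if it is one of the requested ones.
                if (Arrays.binarySearch(sorted, label) >= 0) {
                    result[row][col] = label;
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        long[][] mask = { { 0, 7, 15 }, { 15, 2, 7 } };
        long[][] filtered = filterLabels(mask, new long[] { 15 });
        System.out.println(Arrays.deepToString(filtered)); // prints [[0, 0, 15], [15, 0, 0]]
    }
}
```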

@@ -1,305 +0,0 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.semantic.segmentation.attic;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
import javax.imageio.ImageIO;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.TensorFlowService;
import org.springframework.core.io.DefaultResourceLoader;
import static java.awt.image.BufferedImage.TYPE_3BYTE_BGR;
/**
*
* Semantic image segmentation - the task of assigning a semantic label, such as "road",
* "sky", "person", "dog", to every pixel in an image.
*
* https://ai.googleblog.com/2018/03/semantic-image-segmentation-with.html
* https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
* https://github.com/tensorflow/models/tree/master/research/deeplab
* https://github.com/tensorflow/models/blob/master/research/deeplab/deeplab_demo.ipynb
* http://presentations.cocodataset.org/Places17-GMRI.pdf
*
* http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
* https://www.cityscapes-dataset.com/dataset-overview/#class-definitions
* http://groups.csail.mit.edu/vision/datasets/ADE20K/
*
* https://github.com/mapillary/inplace_abn
*
* @author Christian Tzolov
*/
public class SemanticSegmentationUtils {
/** Name of the model's input tensor. */
public static final String INPUT_TENSOR_NAME = "ImageTensor:0";
/** Name of the model's output tensor holding the per-pixel class predictions. */
public static final String OUTPUT_TENSOR_NAME = "SemanticPredictions:0";
private static final int BATCH_SIZE = 1;
private static final long CHANNELS = 3;
private static final int REQUIRED_INPUT_IMAGE_SIZE = 513;
public static BufferedImage scaledImage(String imagePath) {
try {
return scaledImage(ImageIO.read(new DefaultResourceLoader().getResource(imagePath).getInputStream()));
}
catch (IOException e) {
throw new IllegalStateException("Failed to load Image from: " + imagePath, e);
}
}
public static BufferedImage scaledImage(byte[] image) {
try {
return scaledImage(ImageIO.read(new ByteArrayInputStream(image)));
}
catch (IOException e) {
throw new IllegalStateException("Failed to load Image from byte array", e);
}
}
public static BufferedImage scaledImage(BufferedImage image) {
double scaleRatio = 1.0 * REQUIRED_INPUT_IMAGE_SIZE / Math.max(image.getWidth(), image.getHeight());
return scale(image, scaleRatio);
}
private static BufferedImage scale(BufferedImage originalImage, double scale) {
int newWidth = (int) (originalImage.getWidth() * scale);
int newHeight = (int) (originalImage.getHeight() * scale);
Image tmpImage = originalImage.getScaledInstance(newWidth, newHeight, Image.SCALE_DEFAULT);
// BufferedImage resizedImage = new BufferedImage(newWidth, newHeight,
// TYPE_INT_BGR);
BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, TYPE_3BYTE_BGR);
// BufferedImage resizedImage = new BufferedImage(newWidth, newHeight,
// originalImage.getType());
Graphics2D g2d = resizedImage.createGraphics();
g2d.drawImage(tmpImage, 0, 0, null);
g2d.dispose();
return resizedImage;
}
public static BufferedImage blendMask(BufferedImage mask, BufferedImage background) {
GraphicsUtils.overlayImages(background, mask, 0, 0);
return background;
}
public static Tensor<UInt8> createInputTensor(BufferedImage scaledImage) {
if (scaledImage.getType() != TYPE_3BYTE_BGR) {
throw new IllegalArgumentException(
String.format("Expected 3-byte BGR encoding in BufferedImage, found %d", scaledImage.getType()));
}
// ImageIO.read produces BGR-encoded images, while the model expects RGB.
byte[] data = bgrToRgb(toBytes(scaledImage));
// Expand dimensions since the model expects images to have shape [1, None, None, 3].
long[] shape = new long[] { BATCH_SIZE, scaledImage.getHeight(), scaledImage.getWidth(), CHANNELS };
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}
private static byte[] bgrToRgb(byte[] bgrImage) {
byte[] rgbImage = new byte[bgrImage.length];
// Each pixel is 3 bytes; swap the first (blue) and third (red) channel.
for (int i = 0; i < bgrImage.length; i += 3) {
rgbImage[i] = bgrImage[i + 2];
rgbImage[i + 1] = bgrImage[i + 1];
rgbImage[i + 2] = bgrImage[i];
}
return rgbImage;
}
private static byte[] toBytes(BufferedImage bufferedImage) {
return ((DataBufferByte) bufferedImage.getRaster().getDataBuffer()).getData();
}
public static BufferedImage createMaskImage(int[][] maskPixels, int width, int height, double transparency) {
maskPixels = rotate(maskPixels);
int maskWidth = maskPixels.length;
int maskHeight = maskPixels[0].length;
int[] maskArray = new int[maskWidth * maskHeight];
int k = 0;
for (int i = 0; i < maskHeight; i++) {
for (int j = 0; j < maskWidth; j++) {
Color c = (maskPixels[j][i] == 0) ? Color.BLACK : GraphicsUtils.getClassColor(maskPixels[j][i]);
int t = (int) (255 * (1 - transparency));
maskArray[k++] = new Color(c.getRed(), c.getGreen(), c.getBlue(), t).getRGB();
}
}
// Turn the pixel array into an image.
BufferedImage maskImage = new BufferedImage(maskWidth, maskHeight, BufferedImage.TYPE_INT_ARGB);
maskImage.setRGB(0, 0, maskWidth, maskHeight, maskArray, 0, maskWidth);
// Stretch the image to fit the target box width and height!
return GraphicsUtils.toBufferedImage(maskImage.getScaledInstance(width, height, Image.SCALE_SMOOTH));
}
/**
* Transposes the matrix: swaps rows and columns, so output[y][x] == input[x][y].
* Note that this is a transpose, not a 90 degree rotation.
* @param input The 2D matrix to be transposed
* @return The transposed matrix
*/
private static int[][] rotate(int[][] input) {
int w = input.length;
int h = input[0].length;
int[][] output = new int[h][w];
for (int y = 0; y < h; y++) {
for (int x = w - 1; x >= 0; x--) {
output[y][x] = input[x][y];
}
}
return output;
}
public static int[][] toIntArray(long[][] longArray) {
int[][] intArray = new int[longArray.length][longArray[0].length];
for (int i = 0; i < longArray.length; i++) {
for (int j = 0; j < longArray[0].length; j++) {
intArray[i][j] = (int) longArray[i][j];
}
}
return intArray;
}
public String serializeToJson(int[][] pixels) {
String masksBase64 = Base64.getEncoder().encodeToString(toBytes(pixels));
return String.format("{ \"columns\":%d, \"rows\":%d, \"masks\":\"%s\"}", pixels.length, pixels[0].length,
masksBase64);
}
public int[][] deserializeToMasks(String json) throws IOException {
Map<String, Object> map = new ObjectMapper().readValue(json, Map.class);
int cols = (int) map.get("columns");
int rows = (int) map.get("rows");
String masksBase64 = (String) map.get("masks");
byte[] masks = Base64.getDecoder().decode(masksBase64);
return toInts(masks, cols, rows);
}
private byte[] toBytes(int[][] pixels) {
byte[] b = new byte[pixels.length * pixels[0].length * 4];
int bi = 0;
for (int i = 0; i < pixels.length; i++) {
for (int j = 0; j < pixels[0].length; j++) {
// Encode the pixel value (not the row index) as 4 big-endian bytes.
int value = pixels[i][j];
b[bi + 0] = (byte) (value >> 24);
b[bi + 1] = (byte) (value >> 16);
b[bi + 2] = (byte) (value >> 8);
b[bi + 3] = (byte) (value /* >> 0 */);
bi = bi + 4;
}
}
return b;
}
private int[][] toInts(byte[] b, int ic, int jc) {
int[][] intResult = new int[ic][jc];
int bi = 0;
for (int i = 0; i < ic; i++) {
for (int j = 0; j < jc; j++) {
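// Mask each byte with 0xFF so that negative bytes are not sign-extended.
intResult[i][j] = ((b[bi] & 0xFF) << 24) | ((b[bi + 1] & 0xFF) << 16) | ((b[bi + 2] & 0xFF) << 8)
| (b[bi + 3] & 0xFF);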
bi = bi + 4;
}
}
return intResult;
}
public static void main(String[] args) throws IOException {
// PASCAL VOC 2012
// String tensorflowModelLocation =
// "file:/Users/ctzolov/Downloads/deeplabv3_mnv2_pascal_train_aug/frozen_inference_graph.pb";
// String imagePath = "classpath:/images/VikiMaxiAdi.jpg";
// CITYSCAPE
// String tensorflowModelLocation =
// "file:/Users/ctzolov/Downloads/deeplabv3_mnv2_cityscapes_train/frozen_inference_graph.pb";
// String imagePath = "classpath:/images/amsterdam-cityscape1.jpg";
// String imagePath = "classpath:/images/amsterdam-channel.jpg";
// String imagePath = "classpath:/images/landsmeer.png";
// ADE20K
String tensorflowModelLocation = "file:/Users/ctzolov/Downloads/deeplabv3_xception_ade20k_train/frozen_inference_graph.pb";
String imagePath = "classpath:/images/interior.jpg";
BufferedImage inputImage = ImageIO.read(new DefaultResourceLoader().getResource(imagePath).getInputStream());
TensorFlowService tf = new TensorFlowService(new DefaultResourceLoader().getResource(tensorflowModelLocation),
Arrays.asList(OUTPUT_TENSOR_NAME));
SemanticSegmentationUtils segmentationService = new SemanticSegmentationUtils();
BufferedImage scaledImage = segmentationService.scaledImage(inputImage);
Tensor<UInt8> inTensor = segmentationService.createInputTensor(scaledImage);
Map<String, Tensor<?>> output = tf.apply(Collections.singletonMap(INPUT_TENSOR_NAME, inTensor));
Tensor<?> maskPixelsTensor = output.get(OUTPUT_TENSOR_NAME);
int height = (int) maskPixelsTensor.shape()[1];
int width = (int) maskPixelsTensor.shape()[2];
// Take index 0 because the batch size is 1.
long[][] maskPixels = maskPixelsTensor.copyTo(new long[BATCH_SIZE][height][width])[0];
int[][] maskPixelsInt = segmentationService.toIntArray(maskPixels);
BufferedImage maskImage = segmentationService.createMaskImage(maskPixelsInt, scaledImage.getWidth(),
scaledImage.getHeight(), 0.35);
BufferedImage blended = segmentationService.blendMask(maskImage, scaledImage);
// The images are written in PNG format, so use a matching file extension.
ImageIO.write(maskImage, "png", new File("./semantic-segmentation/target/java2Dmask.png"));
ImageIO.write(blended, "png", new File("./semantic-segmentation/target/java2Dblended.png"));
}
}
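
The `serializeToJson` and `deserializeToMasks` pair above stores each mask value as four big-endian bytes before Base64 encoding. A minimal sketch of that byte round trip using `java.nio.ByteBuffer`, whose default byte order is big-endian, is shown below. `MaskCodecSketch` is a hypothetical name that was not part of the removed module, and the sketch assumes a rectangular matrix.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of the removed module.
final class MaskCodecSketch {

    // Packs a rectangular int matrix into bytes, 4 big-endian bytes per value, row by row.
    static byte[] toBytes(int[][] pixels) {
        int rows = pixels.length;
        int cols = (rows == 0) ? 0 : pixels[0].length;
        ByteBuffer buffer = ByteBuffer.allocate(rows * cols * Integer.BYTES);
        for (int[] row : pixels) {
            for (int value : row) {
                buffer.putInt(value);
            }
        }
        return buffer.array();
    }

    // Reverses toBytes; rows and cols must match the dimensions of the encoded matrix.
    static int[][] toInts(byte[] bytes, int rows, int cols) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        int[][] result = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // getInt reads 4 bytes without sign-extension problems.
                result[i][j] = buffer.getInt();
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] mask = { { 0, 15, 200 }, { -1, 65536, 7 } };
        int[][] decoded = toInts(toBytes(mask), 2, 3);
        System.out.println(Arrays.deepEquals(mask, decoded)); // prints true
    }
}
```

Values above 127 and negative values are the interesting cases here, because a hand-written shift-and-add decoder has to mask each byte with `0xFF` to handle them.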

@@ -1,156 +0,0 @@
{
"name" : "ade20k",
"info" : "ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/)",
"colormap" :[
[ 0, 0, 0 ],
[ 120, 120, 120 ],
[ 180, 120, 120 ],
[ 6, 230, 230 ],
[ 80, 50, 50 ],
[ 4, 200, 3 ],
[ 120, 120, 80 ],
[ 140, 140, 140 ],
[ 204, 5, 255 ],
[ 230, 230, 230 ],
[ 4, 250, 7 ],
[ 224, 5, 255 ],
[ 235, 255, 7 ],
[ 150, 5, 61 ],
[ 120, 120, 70 ],
[ 8, 255, 51 ],
[ 255, 6, 82 ],
[ 143, 255, 140 ],
[ 204, 255, 4 ],
[ 255, 51, 7 ],
[ 204, 70, 3 ],
[ 0, 102, 200 ],
[ 61, 230, 250 ],
[ 255, 6, 51 ],
[ 11, 102, 255 ],
[ 255, 7, 71 ],
[ 255, 9, 224 ],
[ 9, 7, 230 ],
[ 220, 220, 220 ],
[ 255, 9, 92 ],
[ 112, 9, 255 ],
[ 8, 255, 214 ],
[ 7, 255, 224 ],
[ 255, 184, 6 ],
[ 10, 255, 71 ],
[ 255, 41, 10 ],
[ 7, 255, 255 ],
[ 224, 255, 8 ],
[ 102, 8, 255 ],
[ 255, 61, 6 ],
[ 255, 194, 7 ],
[ 255, 122, 8 ],
[ 0, 255, 20 ],
[ 255, 8, 41 ],
[ 255, 5, 153 ],
[ 6, 51, 255 ],
[ 235, 12, 255 ],
[ 160, 150, 20 ],
[ 0, 163, 255 ],
[ 140, 140, 140 ],
[ 250, 10, 15 ],
[ 20, 255, 0 ],
[ 31, 255, 0 ],
[ 255, 31, 0 ],
[ 255, 224, 0 ],
[ 153, 255, 0 ],
[ 0, 0, 255 ],
[ 255, 71, 0 ],
[ 0, 235, 255 ],
[ 0, 173, 255 ],
[ 31, 0, 255 ],
[ 11, 200, 200 ],
[ 255, 82, 0 ],
[ 0, 255, 245 ],
[ 0, 61, 255 ],
[ 0, 255, 112 ],
[ 0, 255, 133 ],
[ 255, 0, 0 ],
[ 255, 163, 0 ],
[ 255, 102, 0 ],
[ 194, 255, 0 ],
[ 0, 143, 255 ],
[ 51, 255, 0 ],
[ 0, 82, 255 ],
[ 0, 255, 41 ],
[ 0, 255, 173 ],
[ 10, 0, 255 ],
[ 173, 255, 0 ],
[ 0, 255, 153 ],
[ 255, 92, 0 ],
[ 255, 0, 255 ],
[ 255, 0, 245 ],
[ 255, 0, 102 ],
[ 255, 173, 0 ],
[ 255, 0, 20 ],
[ 255, 184, 184 ],
[ 0, 31, 255 ],
[ 0, 255, 61 ],
[ 0, 71, 255 ],
[ 255, 0, 204 ],
[ 0, 255, 194 ],
[ 0, 255, 82 ],
[ 0, 10, 255 ],
[ 0, 112, 255 ],
[ 51, 0, 255 ],
[ 0, 194, 255 ],
[ 0, 122, 255 ],
[ 0, 255, 163 ],
[ 255, 153, 0 ],
[ 0, 255, 10 ],
[ 255, 112, 0 ],
[ 143, 255, 0 ],
[ 82, 0, 255 ],
[ 163, 255, 0 ],
[ 255, 235, 0 ],
[ 8, 184, 170 ],
[ 133, 0, 255 ],
[ 0, 255, 92 ],
[ 184, 0, 255 ],
[ 255, 0, 31 ],
[ 0, 184, 255 ],
[ 0, 214, 255 ],
[ 255, 0, 112 ],
[ 92, 255, 0 ],
[ 0, 224, 255 ],
[ 112, 224, 255 ],
[ 70, 184, 160 ],
[ 163, 0, 255 ],
[ 153, 0, 255 ],
[ 71, 255, 0 ],
[ 255, 0, 163 ],
[ 255, 204, 0 ],
[ 255, 0, 143 ],
[ 0, 255, 235 ],
[ 133, 255, 0 ],
[ 255, 0, 235 ],
[ 245, 0, 255 ],
[ 255, 0, 122 ],
[ 255, 245, 0 ],
[ 10, 190, 212 ],
[ 214, 255, 0 ],
[ 0, 204, 255 ],
[ 20, 0, 255 ],
[ 255, 255, 0 ],
[ 0, 153, 255 ],
[ 0, 41, 255 ],
[ 0, 255, 204 ],
[ 41, 0, 255 ],
[ 41, 255, 0 ],
[ 173, 0, 255 ],
[ 0, 245, 255 ],
[ 71, 0, 255 ],
[ 122, 0, 255 ],
[ 0, 255, 184 ],
[ 0, 92, 255 ],
[ 184, 255, 0 ],
[ 0, 133, 255 ],
[ 255, 214, 0 ],
[ 25, 194, 194 ],
[ 102, 255, 0 ],
[ 92, 0, 255 ]]
}

@@ -1,8 +0,0 @@
{
"name" : "black_white",
"info" : "Black and white color map",
"colormap" :[
[ 0, 0, 0 ],
[ 127, 127, 127 ],
[ 255, 255, 255 ]]
}

@@ -1,24 +0,0 @@
{
"name" : "citymap",
"info" : "Cityscapes dataset (https://www.cityscapes-dataset.com).",
"colormap" :[
[ 128, 64, 128 ],
[ 244, 35, 232 ],
[ 70, 70, 70 ],
[ 102, 102, 156 ],
[ 190, 153, 153 ],
[ 153, 153, 153 ],
[ 250, 170, 30 ],
[ 220, 220, 0 ],
[ 107, 142, 35 ],
[ 152, 251, 152 ],
[ 70, 130, 180 ],
[ 220, 20, 60 ],
[ 255, 0, 0 ],
[ 0, 0, 142 ],
[ 0, 0, 70 ],
[ 0, 60, 100 ],
[ 0, 80, 100 ],
[ 0, 0, 230 ],
[ 119, 11, 32 ]]
}
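
Each `colormap` file lists one `[R, G, B]` triple per class. One plausible way to use such a table is to treat the predicted class label as an index into the list and pack the triple into an ARGB pixel. The sketch below assumes exactly that; how the removed `SegmentationColorMap` class actually applied the table is not shown here, and `ColorMapSketch` is a hypothetical name.

```java
// Hypothetical helper for illustration only; not part of the removed module.
final class ColorMapSketch {

    // Converts a [height][width] label mask into a row-major ARGB pixel array.
    // Assumes every label is a valid index into colorMap; alpha is 0..255.
    static int[] toArgb(long[][] mask, int[][] colorMap, int alpha) {
        int height = mask.length;
        int width = (height == 0) ? 0 : mask[0].length;
        int[] argb = new int[height * width];
        int k = 0;
        for (long[] row : mask) {
            for (long label : row) {
                int[] rgb = colorMap[(int) label];
                // Pack alpha, red, green and blue into one 32-bit pixel.
                argb[k++] = (alpha << 24) | (rgb[0] << 16) | (rgb[1] << 8) | rgb[2];
            }
        }
        return argb;
    }

    public static void main(String[] args) {
        // First three entries of the citymap color map above.
        int[][] cityMap = { { 128, 64, 128 }, { 244, 35, 232 }, { 70, 70, 70 } };
        long[][] mask = { { 0, 2 } };
        int[] pixels = toArgb(mask, cityMap, 255);
        System.out.println(Integer.toHexString(pixels[0])); // prints ff804080
    }
}
```

The resulting array has the layout that `BufferedImage.setRGB` expects for a `TYPE_INT_ARGB` image.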

@@ -1,71 +0,0 @@
{
"name" : "mapillary",
"info" : "Mapillary Vistas (https://research.mapillary.com).",
"colormap" :[
[ 165, 42, 42 ],
[ 0, 192, 0 ],
[ 196, 196, 196 ],
[ 190, 153, 153 ],
[ 180, 165, 180 ],
[ 102, 102, 156 ],
[ 102, 102, 156 ],
[ 128, 64, 255 ],
[ 140, 140, 200 ],
[ 170, 170, 170 ],
[ 250, 170, 160 ],
[ 96, 96, 96 ],
[ 230, 150, 140 ],
[ 128, 64, 128 ],
[ 110, 110, 110 ],
[ 244, 35, 232 ],
[ 150, 100, 100 ],
[ 70, 70, 70 ],
[ 150, 120, 90 ],
[ 220, 20, 60 ],
[ 255, 0, 0 ],
[ 255, 0, 0 ],
[ 255, 0, 0 ],
[ 200, 128, 128 ],
[ 255, 255, 255 ],
[ 64, 170, 64 ],
[ 128, 64, 64 ],
[ 70, 130, 180 ],
[ 255, 255, 255 ],
[ 152, 251, 152 ],
[ 107, 142, 35 ],
[ 0, 170, 30 ],
[ 255, 255, 128 ],
[ 250, 0, 30 ],
[ 0, 0, 0 ],
[ 220, 220, 220 ],
[ 170, 170, 170 ],
[ 222, 40, 40 ],
[ 100, 170, 30 ],
[ 40, 40, 40 ],
[ 33, 33, 33 ],
[ 170, 170, 170 ],
[ 0, 0, 142 ],
[ 170, 170, 170 ],
[ 210, 170, 100 ],
[ 153, 153, 153 ],
[ 128, 128, 128 ],
[ 0, 0, 142 ],
[ 250, 170, 30 ],
[ 192, 192, 192 ],
[ 220, 220, 0 ],
[ 180, 165, 180 ],
[ 119, 11, 32 ],
[ 0, 0, 142 ],
[ 0, 60, 100 ],
[ 0, 0, 142 ],
[ 0, 0, 90 ],
[ 0, 0, 230 ],
[ 0, 80, 100 ],
[ 128, 64, 64 ],
[ 0, 0, 110 ],
[ 0, 0, 70 ],
[ 0, 0, 192 ],
[ 32, 32, 32 ],
[ 0, 0, 0 ],
[ 0, 0, 0 ]]
}

Binary file not shown. Size before: 184 KiB

Binary file not shown. Size before: 243 KiB

Binary file not shown. Size before: 78 KiB

Binary file not shown. Size before: 102 KiB

Binary file not shown. Size before: 236 KiB

Binary file not shown. Size before: 149 KiB