GH-20: Add Computer Vision function

Fixes: #20

* Update djl spring to `0.26`
This commit is contained in:
Christian Tzolov
2024-01-22 12:25:12 -05:00
committed by Artem Bilan
parent 77112eb8a1
commit 55f09da388
20 changed files with 1223 additions and 1 deletions

2
.gitignore vendored
View File

@@ -6,4 +6,4 @@ bin/
.gradle/
.idea/
out/
.vscode/

View File

@@ -59,6 +59,7 @@ allprojects {
mavenBom "io.awspring.cloud:spring-cloud-aws-dependencies:$springCloudAwsVersion"
mavenBom "org.springframework.boot:spring-boot-dependencies:$springBootVersion"
mavenBom "org.springframework.cloud:spring-cloud-dependencies:$springCloudVersion"
mavenBom "ai.djl:bom:$djlVersion"
}
}
}

View File

@@ -4,6 +4,8 @@ ext {
springCloudAwsVersion = '3.0.4'
debeziumVersion = '2.5.2.Final'
djlVersion = '0.26.0'
djlSpringVersion = '0.26'
springIntegrationAws = 'org.springframework.integration:spring-integration-aws:3.0.5'
ftpserverCore = 'org.apache.ftpserver:ftpserver-core:1.2.0'
@@ -11,4 +13,5 @@ ext {
twitter4jStream = 'org.twitter4j:twitter4j-stream:4.0.7'
greenmail = 'com.icegreen:greenmail:2.1.0-alpha-4'
apacheCuratorTest = 'org.apache.curator:curator-test:5.5.0'
djlSpringAutoconfigure = "ai.djl.spring:djl-spring-boot-starter-autoconfigure:$djlSpringVersion"
}

View File

@@ -0,0 +1,220 @@
= Computer Vision Functions
This module provides functional interface to perform common Computer Vision tasks such as Image Classification, Object Detection, Instance and Semantic Segmentation, Pose Estimation an more.
It leverages the https://docs.djl.ai/index.html[Deep Java Library] (DJL) to enable Java developers to harness the power of deep learning.
DJL serves as a bridge between the rich ecosystem of Java programming and the cutting-edge capabilities of deep learning.
DJL provides integration with popular deep learning frameworks like `TensorFlow`, `PyTorch`, and `MXNet`, as well as support for a variety of pre-trained models using `ONNX Runtime`.
== Beans for injection
This module exposes auto-configuration for the following bean:
`Function<Message<byte[]>, Message<byte[]>> computerVisionFunction`
However, the `ComputerVisionFunctionConfiguration` provides a set of conditional beans based on specific configuration properties.
[%autowidth]
|===
|Bean |Activation Properties
|objectDetection
|djl.output-class=ai.djl.modality.cv.output.DetectedObjects
|imageClassifications
|djl.output-class=ai.djl.modality.Classifications
|semanticSegmentation
|djl.output-class=ai.djl.modality.cv.output.CategoryMask
|poseEstimation
|djl.output-class=ai.djl.modality.cv.output.Joints
|===
* `objectDetection` - Offering `Object Detection` for finding all instances of objects from a known set of categories in an image and `Instance Segmentation` for finding all instances of objects from a known set of categories in an image and drawing a mask on each instance.
* `imageClassifications` - The `Image Classification` task assigns a label to an image from a set of categories.
* `semanticSegmentation` - `Semantic Segmentation` refers to the task of detecting objects of various classes at pixel level.
It colors the pixels based on the objects detected in that space.
* `poseEstimation` - `Pose Estimation` refers to the task of detecting human figures in images and videos, and estimating the pose of the bodies.
Once injected, you can use the `apply` method of the `Function` to invoke it and get the result.
The function takes and returns a `Message<byte[]>`.
The input message payload contains the image bytes to be processed.
The output message payload contains the original or the augmented image after the processing.
The `computer.vision.function.augment-enabled` property controls whether the augmented image is returned or not.
Defaults to `true`.
== Configuration Options
[%autowidth]
|===
|Property |Description
|djl.application-type
|Defines the CV application task to be performed. Currently supported values are `OBJECT_DETECTION`, `IMAGE_CLASSIFICATION`, `INSTANCE_SEGMENTATION`, `SEMANTIC_SEGMENTATION` and `POSE_ESTIMATION`.
|djl.input-class
|Define input data type, a model may accept multiple input data type. Currently only the `ai.djl.modality.cv.Image` is supported.
|djl.output-class
|Define output data type, a model may generate different outputs. Supported output classes are `ai.djl.modality.cv.output.DetectedObjects`, `ai.djl.modality.cv.output.CategoryMask`, `ai.djl.modality.Classifications`, `ai.djl.modality.cv.output.Joints` .
|djl.urls
|Model repository URLs. Multiple may be supplied to search for models. Specifying a single URL can be used to load a specific model. Can be specified as comma delimited field or as an array in the configuration file.
Current supported archive formats: `zip`, `tar`, `tar.gz`, `tgz`, `tar.z`.
Supported URL schemes: `file://` - load a model from local directory or archive file., `http(s)://` - load a model from an archive file from web server, `jar://` - load a model from an archive file in the class path, `djl://` - load a model from the model zoo, `s3://` - load a model from S3 bucket (requires djl aws extension), `hdfs://` - load a model from HDFS file system (requires djl hadoop extension)
|djl.model-filter
| https://github.com/deepjavalibrary/djl/tree/master/model-zoo#how-to-find-a-pre-trained-model-in-the-model-zoo[Model Filters] used to lookup a model from model zoo .
|djl.group-id
|Defines the `groupId` of the model to be loaded from the zoo.
|djl.model-artifact-id
|Defines the `artifactId` of the model to be loaded from the zoo.
|djl.model-name
|(Optional) Defines the modelName of the model to be loaded.
Leave it empty if you want to load the latest version of the model.
Use "saved_model" for TensorFlow saved models.
|djl.engine
| Name of teh https://docs.djl.ai/docs/engine.html[Engine] to use https://docs.djl.ai/docs/engine.html#supported-engines[Supported engine names].
|djl.translator-factory
| https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html[Translator] provides model pre-processing and postprocessing functionality. Multiple https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/package-summary.html[translators] are provided for different models, but you can implement your own translator if needed (see []). The translator-factory property allow to specify the translator to be used with the model.
|computer.vision.function.output-header-name
|Name of the header that contains the JSON payload computed by the functions.
|computer.vision.function.augment-enabled
|Enable image augmentation (false by default).
|===
Also, this function exposes its specific properties with a `computer.vision.function` prefix.
See link:src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionProperties.java[ComputerVisionFunctionProperties] for more details.
=== Example Configurations
All computer vision examples use the following Java code snippet to invoke the function:
[source,Java]
----
@SpringBootApplication
public class TfObjectDetectionBootApp implements CommandLineRunner {
@Autowired
private Function<Message<byte[]>, Message<byte[]>> cvFunction;
@Override
public void run(String... args) throws Exception {
byte[] inputImage = new ClassPathResource("Image URI").getInputStream().readAllBytes();
Message<byte[]> outputMessage = cvFunction.apply(
MessageBuilder.withPayload(inputImage).build());
// Augmented output image.
byte[] outputImage = outputMessage.getPayload();
// JSON payload with the detected objects and their bounding boxes.
String jsonBoundingBoxes = outputMessage.getHeader("cvjson", String.class);
}
public static void main(String[] args) {
SpringApplication.run(TfObjectDetectionBootApp.class);
}
}
----
==== Object Detection (TensorFlow)
You can leverage any of the existing TensorFlow models.
Just comply the url of the model archive as a `djl.urls` property and set the `djl.translator-factory` to `org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory`.
----
computer.vision.function.augment-enabled=true
djl.application-type=OBJECT_DETECTION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.engine=TensorFlow
djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz
djl.model-name=saved_model
djl.translator-factory=org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory
djl.arguments.threshold=0.3
----
==== Object Detection (Yolo v8)
You can use the same Java snipped above, just change the configuration to use the Yolo v8 model:
----
computer.vision.function.augment-enabled=true
djl.application-type=OBJECT_DETECTION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.engine=OnnxRuntime
djl.urls=djl://ai.djl.onnxruntime/yolov8n
djl.translator-factory=ai.djl.modality.cv.translator.YoloV8TranslatorFactory
djl.arguments.threshold=0.3
djl.arguments.width=640
djl.arguments.height=640
djl.arguments.resize=true
djl.arguments.toTensor=true
djl.arguments.applyRatio=true
djl.arguments.maxBox=1000
----
==== Instance Segmentation
Same Java code snipped but with the following configuration:
----
computer.vision.function.augment-enabled=true
djl.application-type=INSTANCE_SEGMENTATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
djl.arguments.threshold=0.3
djl.model-filter.backbone=resnet18
djl.model-filter.flavor=v1b
djl.model-filter.dataset=coco
----
Note that here we didn't specify the model to be used, but used the model-filter to find a compatible model from the model zoo.
==== Semantic Segmentation
Same Java code snipped but with the following configuration:
----
computer.vision.function.augment-enabled=true
djl.application-type=SEMANTIC_SEGMENTATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.cv.output.CategoryMask
djl.arguments.threshold=0.3
djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip
djl.translator-factory=ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory
djl.engine=PyTorch
----
==== Image Classification
----
djl.application-type=IMAGE_CLASSIFICATION
djl.input-class=ai.djl.modality.cv.Image
djl.output-class=ai.djl.modality.Classifications
djl.arguments.threshold=0.3
djl.engine=MXNet
----
== Tests
See this link:src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java[test suite] for examples of how this function is used.
The link:src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java[JsonHelperTests] validates the JSON serialization and deserialization of the `ComputerVisionFunctionConfiguration` class values object classes.

View File

@@ -0,0 +1,7 @@
dependencies {
api djlSpringAutoconfigure
api "ai.djl.spring:djl-spring-boot-starter-tensorflow-auto:$djlSpringVersion"
api "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:$djlSpringVersion"
api "ai.djl.spring:djl-spring-boot-starter-mxnet-auto:$djlSpringVersion"
runtimeOnly 'ai.djl.onnxruntime:onnxruntime-engine'
}

View File

@@ -0,0 +1,163 @@
/*
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.imageio.ImageIO;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.spring.configuration.DjlAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
/**
* A configuration class that provides the necessary beans for the Computer Vision
* Function.
*
* @author Christian Tzolov
*/
@AutoConfiguration(after = DjlAutoConfiguration.class)
@EnableConfigurationProperties(ComputerVisionFunctionProperties.class)
public class ComputerVisionFunctionConfiguration {
private static final ImageFactory IMAGE_FACTORY = ImageFactory.getInstance();
private final Supplier<Predictor<?, ?>> predictorProvider;
private final ComputerVisionFunctionProperties cvProperties;
public ComputerVisionFunctionConfiguration(Supplier<Predictor<?, ?>> predictorProvider,
ComputerVisionFunctionProperties cvProperties) {
this.predictorProvider = predictorProvider;
this.cvProperties = cvProperties;
}
@Bean(name = "computerVisionFunction")
@ConditionalOnProperty(prefix = "djl", name = "output-class",
havingValue = "ai.djl.modality.cv.output.DetectedObjects")
public Function<Message<byte[]>, Message<byte[]>> objectDetection() {
BiFunction<DetectedObjects, Image, byte[]> augmentImage = (detectedObjects, image) -> {
Image newImage = image.duplicate();
newImage.drawBoundingBoxes(detectedObjects);
return getByteArray((RenderedImage) newImage.getWrappedImage(),
this.cvProperties.getOutputImageFormatName());
};
return predictor(JsonHelper::toJson, augmentImage);
}
@Bean(name = "computerVisionFunction")
@ConditionalOnProperty(prefix = "djl", name = "output-class",
havingValue = "ai.djl.modality.cv.output.CategoryMask")
public Function<Message<byte[]>, Message<byte[]>> semanticSegmentation() {
BiFunction<CategoryMask, Image, byte[]> augmentImage = (mask, image) -> {
Image newImage = image.duplicate();
mask.drawMask(newImage, 200, 0);
return getByteArray((RenderedImage) newImage.getWrappedImage(),
this.cvProperties.getOutputImageFormatName());
};
return predictor(JsonHelper::toJson, augmentImage);
}
@Bean(name = "computerVisionFunction")
@ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = "ai.djl.modality.Classifications")
public Function<Message<byte[]>, Message<byte[]>> imageClassifications() {
BiFunction<Classifications, Image, byte[]> augmentImage = (classifications, image) -> {
Image newImage = image.duplicate();
return getByteArray((RenderedImage) newImage.getWrappedImage(),
this.cvProperties.getOutputImageFormatName());
};
return predictor(JsonHelper::toJson, augmentImage);
}
@Bean(name = "computerVisionFunction")
@ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = "ai.djl.modality.cv.output.Joints")
public Function<Message<byte[]>, Message<byte[]>> poseEstimation() {
BiFunction<Joints, Image, byte[]> augmentImage = (joints, image) -> {
Image newImage = image.duplicate();
newImage.drawJoints(joints);
return getByteArray((RenderedImage) newImage.getWrappedImage(),
this.cvProperties.getOutputImageFormatName());
};
return predictor(JsonHelper::toJson, augmentImage);
}
@SuppressWarnings("unchecked")
private <T> Function<Message<byte[]>, Message<byte[]>> predictor(Function<T, String> toJsonFunction,
BiFunction<T, Image, byte[]> augmentImageFunction) {
return (input) -> {
try (Predictor<Image, T> predictor = (Predictor<Image, T>) this.predictorProvider.get()) {
Image image = IMAGE_FACTORY.fromInputStream(new ByteArrayInputStream(input.getPayload()));
T output = predictor.predict(image);
String outputJson = toJsonFunction.apply(output);
byte[] outputImageBytes = input.getPayload();
if (this.cvProperties.isAugmentEnabled()) {
outputImageBytes = augmentImageFunction.apply(output, image);
}
String headerName = this.cvProperties.getOutputHeaderName();
return MessageBuilder.withPayload(outputImageBytes).setHeader(headerName, outputJson).build();
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
};
}
private static byte[] getByteArray(RenderedImage image, String formatName) {
try {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ImageIO.write(image, formatName, byteArrayOutputStream);
return byteArrayOutputStream.toByteArray();
}
catch (IOException ex) {
throw new UncheckedIOException(ex);
}
}
}

View File

@@ -0,0 +1,68 @@
/*
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
* Configuration properties for the Computer Vision Function.
*
* @author Christian Tzolov
*/
@ConfigurationProperties("computer.vision.function")
public class ComputerVisionFunctionProperties {
/**
* Enable image augmentation.
*/
private boolean augmentEnabled = false;
/**
* Output augmented image format name.
*/
private String outputImageFormatName = "png";
/**
* Name of the header that contains the JSON payload computed by the functions.
*/
private String outputHeaderName = "cvjson";
public boolean isAugmentEnabled() {
return this.augmentEnabled;
}
public void setAugmentEnabled(boolean augmentImage) {
this.augmentEnabled = augmentImage;
}
public String getOutputImageFormatName() {
return this.outputImageFormatName;
}
public void setOutputImageFormatName(String outputImageFormatName) {
this.outputImageFormatName = outputImageFormatName;
}
public String getOutputHeaderName() {
return this.outputHeaderName;
}
public void setOutputHeaderName(String jsonHeaderName) {
this.outputHeaderName = jsonHeaderName;
}
}

View File

@@ -0,0 +1,112 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision;
import java.lang.reflect.Type;
import java.util.List;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.util.JsonUtils;
import com.google.gson.Gson;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonParseException;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
/**
* Helper class to serialize and deserialize {@link DetectedObjects},
* {@link Classifications}, {@link CategoryMask} and {@link Joints} to/from JSON.
*
* @author Christian Tzolov
*/
public final class JsonHelper {
private static final Gson GSON = JsonUtils.builder().create();
private JsonHelper() {
}
public static String toJson(Joints joints) {
return GSON.toJson(joints);
}
public static Joints toJoints(String json) {
return GSON.fromJson(json, Joints.class);
}
public static String toJson(CategoryMask categoryMask) {
return GSON.toJson(Mask.fromCategoryMask(categoryMask));
}
public static CategoryMask toCategoryMask(String json) {
return GSON.fromJson(json, Mask.class).toCategoryMask();
}
public static String toJson(Classifications classifications) {
return GSON.toJson(classifications);
}
public static Classifications toClassifications(String json) {
return GSON.fromJson(json, Classifications.class);
}
private static final Gson GSON2 = JsonUtils.builder()
.registerTypeAdapter(BoundingBox.class, new BoundingBoxAdapter())
.create();
public static String toJson(DetectedObjects detectedObject) {
return GSON2.toJson(detectedObject);
}
public static DetectedObjects toDetectedObjects(String json) {
return GSON2.fromJson(json, DetectedObjects.class);
}
public record Mask(List<String> classes, int[][] mask) {
public static Mask fromCategoryMask(CategoryMask categoryMask) {
return new Mask(categoryMask.getClasses(), categoryMask.getMask());
}
public CategoryMask toCategoryMask() {
return new CategoryMask(this.classes, this.mask);
}
}
public static class BoundingBoxAdapter implements JsonSerializer<BoundingBox>, JsonDeserializer<BoundingBox> {
@Override
public JsonElement serialize(BoundingBox boundingBox, Type typeOfSrc, JsonSerializationContext context) {
return context.serialize(boundingBox);
}
@Override
public BoundingBox deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context)
throws JsonParseException {
return context.deserialize(json, Rectangle.class);
}
}
}

View File

@@ -0,0 +1,4 @@
/**
* Provides classes for the Computer Vision Function.
*/
package org.springframework.cloud.fn.computer.vision;

View File

@@ -0,0 +1,184 @@
/*
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision.translator;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Scanner;
import java.util.concurrent.ConcurrentHashMap;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import com.google.gson.annotations.SerializedName;
/**
* A {@link NoBatchifyTranslator} that post-processes the output of a TensorFlow
* SavedModel Object Detection model.
*
* @author Christian Tzolov
*/
public final class TensorflowSavedModelObjectDetectionTranslator
implements NoBatchifyTranslator<Image, DetectedObjects> {
private static final String ITEM_DELIMITER = "item ";
private static final String DEFAULT_MSCOCO_LABELS_URL = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
private static final String DETECTION_BOXES = "detection_boxes";
private static final String DETECTION_SCORES = "detection_scores";
private static final String DETECTION_CLASSES = "detection_classes";
private String classLabelsUrl;
private Map<Integer, String> classLabels;
private int maxBoxes;
private float threshold;
public TensorflowSavedModelObjectDetectionTranslator() {
this(DEFAULT_MSCOCO_LABELS_URL, 10, 0.3f);
}
public TensorflowSavedModelObjectDetectionTranslator(String categoryLabelsUrl, int maxBoxes, float threshold) {
this.classLabelsUrl = categoryLabelsUrl;
this.maxBoxes = maxBoxes;
this.threshold = threshold;
}
/** {@inheritDoc} */
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
// input to tf object-detection models is a list of tensors, hence NDList
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
// optionally resize the image for faster processing
array = NDImageUtils.resize(array, 224);
// tf object-detection models expect 8 bit unsigned integer tensor
array = array.toType(DataType.UINT8, true);
// tf object-detection models expect a 4 dimensional input
array = array.expandDims(0);
return new NDList(array);
}
/** {@inheritDoc} */
@Override
public void prepare(TranslatorContext ctx) throws IOException {
if (this.classLabels == null) {
this.classLabels = loadSynset();
}
}
private Map<Integer, String> loadSynset() throws IOException {
Map<Integer, String> map = new ConcurrentHashMap<>();
int maxId = 0;
try (InputStream is = new BufferedInputStream(new URL(this.classLabelsUrl).openStream());
Scanner scanner = new Scanner(is, StandardCharsets.UTF_8)) {
scanner.useDelimiter(ITEM_DELIMITER);
while (scanner.hasNext()) {
String content = scanner.next();
content = content.replaceAll("(\"|\\d)\\n\\s", "$1,");
Item item = JsonUtils.GSON.fromJson(content, Item.class);
map.put(item.id, item.displayName);
if (item.id > maxId) {
maxId = item.id;
}
}
}
return map;
}
/** {@inheritDoc} */
@Override
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
// output of tf object-detection models is a list of tensors, hence NDList in djl
// output NDArray order in the list are not guaranteed
int[] classIds = null;
float[] probabilities = null;
NDArray boundingBoxes = null;
for (NDArray array : list) {
if (DETECTION_BOXES.equals(array.getName())) {
boundingBoxes = array.get(0);
}
else if (DETECTION_SCORES.equals(array.getName())) {
probabilities = array.get(0).toFloatArray();
}
else if (DETECTION_CLASSES.equals(array.getName())) {
// class id is between 1 - number of classes
classIds = array.get(0).toType(DataType.INT32, true).toIntArray();
}
}
Objects.requireNonNull(classIds);
Objects.requireNonNull(probabilities);
Objects.requireNonNull(boundingBoxes);
List<String> retNames = new ArrayList<>();
List<Double> retProbs = new ArrayList<>();
List<BoundingBox> retBB = new ArrayList<>();
// result are already sorted
for (int i = 0; i < Math.min(classIds.length, this.maxBoxes); ++i) {
int classId = classIds[i];
double probability = probabilities[i];
// classId starts from 1, -1 means background
if (classId > 0 && probability > this.threshold) {
String className = this.classLabels.getOrDefault(classId, "#" + classId);
float[] box = boundingBoxes.get(i).toFloatArray();
float yMin = box[0];
float xMin = box[1];
float yMax = box[2];
float xMax = box[3];
Rectangle rect = new Rectangle(xMin, yMin, xMax - xMin, yMax - yMin);
retNames.add(className);
retProbs.add(probability);
retBB.add(rect);
}
}
return new DetectedObjects(retNames, retProbs, retBB);
}
private static final class Item {
int id;
@SerializedName("display_name")
String displayName;
}
}

View File

@@ -0,0 +1,39 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision.translator;
import java.util.Map;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.translator.ObjectDetectionTranslatorFactory;
import ai.djl.translate.Translator;
/**
* Translator for TensorFlow Object Detection SavedModel.
*
* @author Christian Tzolov
*/
public class TensorflowSavedModelObjectDetectionTranslatorFactory extends ObjectDetectionTranslatorFactory {
@Override
protected Translator<Image, DetectedObjects> buildBaseTranslator(Model model, Map<String, ?> arguments) {
return new TensorflowSavedModelObjectDetectionTranslator();
}
}

View File

@@ -0,0 +1,4 @@
/**
* Provides classes for translating the output of the computer vision function.
*/
package org.springframework.cloud.fn.computer.vision.translator;

View File

@@ -0,0 +1 @@
org.springframework.cloud.fn.computer.vision.ComputerVisionFunctionConfiguration

View File

@@ -0,0 +1,331 @@
/*
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision;
import java.util.function.Function;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory;
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.spring.configuration.ApplicationType;
import ai.djl.spring.configuration.DjlAutoConfiguration;
import ai.djl.spring.configuration.DjlConfigurationProperties;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.OS;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory;
import org.springframework.core.io.ClassPathResource;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import static org.assertj.core.api.Assertions.assertThat;
public class ComputerVisionFunctionConfigurationTests {
private ApplicationContextRunner applicationContextRunner;
@BeforeEach
public void setUp() {
applicationContextRunner = new ApplicationContextRunner().withConfiguration(
AutoConfigurations.of(DjlAutoConfiguration.class, ComputerVisionFunctionConfiguration.class));
}
/**
* This configuration can be used to load any of the Tensorflow2 models for object
* detection from here:
* https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md
*/
@Test
public void tf2SavedModel() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=true",
"djl.application-type=" + ApplicationType.OBJECT_DETECTION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + DetectedObjects.class.getName(),
"djl.engine=TensorFlow",
"djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz",
"djl.model-name=saved_model",
"djl.translator-factory=" + TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName(),
"djl.arguments.threshold=0.3")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
assertThat(djlProperties.getEngine()).isEqualTo("TensorFlow");
assertThat(djlProperties.getUrls()).contains(
"http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz");
assertThat(djlProperties.getModelName()).isEqualTo("saved_model");
assertThat(djlProperties.getTranslatorFactory())
.isEqualTo(TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName());
byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
var json = outputMessage.getHeaders().get("cvjson", String.class);
assertThat(JsonHelper.toDetectedObjects(json)).isNotNull();
});
}
@Test
public void yolov8Detection() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=true",
"djl.application-type=" + ApplicationType.OBJECT_DETECTION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + DetectedObjects.class.getName(),
"djl.engine=OnnxRuntime",
"djl.urls=djl://ai.djl.onnxruntime/yolov8n",
"djl.translator-factory=" + YoloV8TranslatorFactory.class.getName(),
"djl.arguments.threshold=0.3",
"djl.arguments.width=640",
"djl.arguments.height=640",
"djl.arguments.resize=true",
"djl.arguments.toTensor=true",
"djl.arguments.applyRatio=true",
"djl.arguments.maxBox=1000")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
assertThat(djlProperties.getEngine()).isEqualTo("OnnxRuntime");
assertThat(djlProperties.getUrls()).contains("djl://ai.djl.onnxruntime/yolov8n");
assertThat(djlProperties.getTranslatorFactory()).isEqualTo(YoloV8TranslatorFactory.class.getName());
byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
var json = outputMessage.getHeaders().get("cvjson", String.class);
var detectionObjects = JsonHelper.toDetectedObjects(json);
assertThat(detectionObjects).isNotNull();
});
}
@Test
public void instanceSegmentation() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=true",
"djl.application-type=" + ApplicationType.INSTANCE_SEGMENTATION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + DetectedObjects.class.getName(),
"djl.arguments.threshold=0.3",
"djl.model-filter.backbone=resnet18",
"djl.model-filter.flavor=v1b",
"djl.model-filter.dataset=coco")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.INSTANCE_SEGMENTATION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
// byte[] inputImage = new
// ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
String json = outputMessage.getHeaders().get("cvjson", String.class);
assertThat(JsonHelper.toDetectedObjects(json)).isNotNull();
});
}
@DisabledOnOs(OS.WINDOWS)
@Test
public void semanticSegmentation() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=true",
"djl.application-type=" + ApplicationType.SEMANTIC_SEGMENTATION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + CategoryMask.class.getName(),
"djl.arguments.threshold=0.3",
"djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip",
"djl.translator-factory=" + SemanticSegmentationTranslatorFactory.class.getName(),
"djl.engine=PyTorch")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.SEMANTIC_SEGMENTATION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(CategoryMask.class);
byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
String ssJson = outputMessage.getHeaders().get("cvjson", String.class);
assertThat(JsonHelper.toCategoryMask(ssJson)).isNotNull();
});
}
@Test
public void imageClassifications() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=false",
"djl.application-type=" + ApplicationType.IMAGE_CLASSIFICATION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + Classifications.class.getName(),
"djl.arguments.threshold=0.3",
"djl.engine=MXNet")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.IMAGE_CLASSIFICATION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(Classifications.class);
byte[] inputImage = new ClassPathResource("/karakatschan.jpg").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
String json = outputMessage.getHeaders().get("cvjson", String.class);
assertThat(JsonHelper.toClassifications(json)).isNotNull();
});
}
@Test
public void poseEstimation() {
applicationContextRunner.withPropertyValues(
// @formatter:off
"computer.vision.function.augment-enabled=true",
"djl.application-type=" + ApplicationType.POSE_ESTIMATION,
"djl.input-class=" + Image.class.getName(),
"djl.output-class=" + Joints.class.getName(),
"djl.arguments.threshold=0.3",
"djl.model-filter.backbone=resnet18",
"djl.model-filter.flavor=v1b",
"djl.model-filter.dataset=imagenet")
// @formatter:on
.run((context) -> {
assertThat(context).hasSingleBean(ZooModel.class);
assertThat(context).hasBean("predictorProvider");
@SuppressWarnings("unchecked")
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
.getBean("computerVisionFunction");
var djlProperties = context.getBean(DjlConfigurationProperties.class);
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.POSE_ESTIMATION);
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
assertThat(djlProperties.getOutputClass()).isEqualTo(Joints.class);
byte[] inputImage = new ClassPathResource("/pose.png").getInputStream().readAllBytes();
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
assertThat(outputMessage).isNotNull();
assertThat(outputMessage.getPayload()).isNotNull();
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
String ssJson = outputMessage.getHeaders().get("cvjson", String.class);
assertThat(JsonHelper.toJoints(ssJson)).isNotNull();
});
}
}

View File

@@ -0,0 +1,85 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.computer.vision;
import java.util.List;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
*/
public class JsonHelperTests {
@Test
public void categoryMask() {
var categoryMask = new CategoryMask(List.of("a", "b", "c"), new int[][] { { 1, 2, 3 }, { 4, 5, 6 } });
var json = JsonHelper.toJson(categoryMask);
assertThat(json).isNotEmpty();
var categoryMask2 = JsonHelper.toCategoryMask(json);
assertThat(categoryMask.getClasses()).isEqualTo(categoryMask2.getClasses());
assertThat(categoryMask.getMask()).isEqualTo(categoryMask2.getMask());
}
@Test
public void classifications() {
var classifications = new Classifications(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3));
classifications.setTopK(3);
var json = JsonHelper.toJson(classifications);
assertThat(json).isNotEmpty();
var classifications2 = JsonHelper.toClassifications(json);
assertThat(classifications2.getClassNames()).isEqualTo(classifications.getClassNames());
assertThat(classifications2.getProbabilities()).isEqualTo(classifications.getProbabilities());
assertThat(classifications2.topK()).hasSize(3);
}
@Test
public void detectedObjects() {
DetectedObjects detectedObjects = new DetectedObjects(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3),
List.of(new Rectangle(1, 2, 3, 4), new Rectangle(5, 6, 7, 8), new Rectangle(9, 10, 11, 12)));
detectedObjects.setTopK(3);
var json = JsonHelper.toJson(detectedObjects);
assertThat(json).isNotEmpty();
var detectedObjects2 = JsonHelper.toDetectedObjects(json);
assertThat(detectedObjects2.getClassNames()).isEqualTo(detectedObjects.getClassNames());
assertThat(detectedObjects2.getProbabilities()).isEqualTo(detectedObjects.getProbabilities());
assertThat(detectedObjects2.topK()).hasSize(3);
assertThat(detectedObjects2.getNumberOfObjects()).isEqualTo(3);
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1013 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB