GH-20: Add Computer Vision function
Fixes: #20 * Update djl spring to `0.26`
This commit is contained in:
committed by
Artem Bilan
parent
77112eb8a1
commit
55f09da388
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,4 +6,4 @@ bin/
|
||||
.gradle/
|
||||
.idea/
|
||||
out/
|
||||
|
||||
.vscode/
|
||||
|
||||
@@ -59,6 +59,7 @@ allprojects {
|
||||
mavenBom "io.awspring.cloud:spring-cloud-aws-dependencies:$springCloudAwsVersion"
|
||||
mavenBom "org.springframework.boot:spring-boot-dependencies:$springBootVersion"
|
||||
mavenBom "org.springframework.cloud:spring-cloud-dependencies:$springCloudVersion"
|
||||
mavenBom "ai.djl:bom:$djlVersion"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ ext {
|
||||
springCloudAwsVersion = '3.0.4'
|
||||
|
||||
debeziumVersion = '2.5.2.Final'
|
||||
djlVersion = '0.26.0'
|
||||
djlSpringVersion = '0.26'
|
||||
|
||||
springIntegrationAws = 'org.springframework.integration:spring-integration-aws:3.0.5'
|
||||
ftpserverCore = 'org.apache.ftpserver:ftpserver-core:1.2.0'
|
||||
@@ -11,4 +13,5 @@ ext {
|
||||
twitter4jStream = 'org.twitter4j:twitter4j-stream:4.0.7'
|
||||
greenmail = 'com.icegreen:greenmail:2.1.0-alpha-4'
|
||||
apacheCuratorTest = 'org.apache.curator:curator-test:5.5.0'
|
||||
djlSpringAutoconfigure = "ai.djl.spring:djl-spring-boot-starter-autoconfigure:$djlSpringVersion"
|
||||
}
|
||||
|
||||
220
function/spring-computer-vision-function/README.adoc
Normal file
220
function/spring-computer-vision-function/README.adoc
Normal file
@@ -0,0 +1,220 @@
|
||||
= Computer Vision Functions
|
||||
|
||||
This module provides functional interface to perform common Computer Vision tasks such as Image Classification, Object Detection, Instance and Semantic Segmentation, Pose Estimation an more.
|
||||
|
||||
It leverages the https://docs.djl.ai/index.html[Deep Java Library] (DJL) to enable Java developers to harness the power of deep learning.
|
||||
DJL serves as a bridge between the rich ecosystem of Java programming and the cutting-edge capabilities of deep learning.
|
||||
DJL provides integration with popular deep learning frameworks like `TensorFlow`, `PyTorch`, and `MXNet`, as well as support for a variety of pre-trained models using `ONNX Runtime`.
|
||||
|
||||
== Beans for injection
|
||||
|
||||
This module exposes auto-configuration for the following bean:
|
||||
|
||||
`Function<Message<byte[]>, Message<byte[]>> computerVisionFunction`
|
||||
|
||||
However, the `ComputerVisionFunctionConfiguration` provides a set of conditional beans based on specific configuration properties.
|
||||
|
||||
[%autowidth]
|
||||
|===
|
||||
|Bean |Activation Properties
|
||||
|
||||
|objectDetection
|
||||
|djl.output-class=ai.djl.modality.cv.output.DetectedObjects
|
||||
|
||||
|imageClassifications
|
||||
|djl.output-class=ai.djl.modality.Classifications
|
||||
|
||||
|semanticSegmentation
|
||||
|djl.output-class=ai.djl.modality.cv.output.CategoryMask
|
||||
|
||||
|poseEstimation
|
||||
|djl.output-class=ai.djl.modality.cv.output.Joints
|
||||
|
||||
|===
|
||||
|
||||
* `objectDetection` - Offering `Object Detection` for finding all instances of objects from a known set of categories in an image and `Instance Segmentation` for finding all instances of objects from a known set of categories in an image and drawing a mask on each instance.
|
||||
* `imageClassifications` - The `Image Classification` task assigns a label to an image from a set of categories.
|
||||
* `semanticSegmentation` - `Semantic Segmentation` refers to the task of detecting objects of various classes at pixel level.
|
||||
It colors the pixels based on the objects detected in that space.
|
||||
* `poseEstimation` - `Pose Estimation` refers to the task of detecting human figures in images and videos, and estimating the pose of the bodies.
|
||||
|
||||
Once injected, you can use the `apply` method of the `Function` to invoke it and get the result.
|
||||
|
||||
The function takes and returns a `Message<byte[]>`.
|
||||
The input message payload contains the image bytes to be processed.
|
||||
The output message payload contains the original or the augmented image after the processing.
|
||||
The `computer.vision.function.augment-enabled` property controls whether the augmented image is returned or not.
|
||||
Defaults to `true`.
|
||||
|
||||
== Configuration Options
|
||||
|
||||
[%autowidth]
|
||||
|===
|
||||
|Property |Description
|
||||
|
||||
|djl.application-type
|
||||
|Defines the CV application task to be performed. Currently supported values are `OBJECT_DETECTION`, `IMAGE_CLASSIFICATION`, `INSTANCE_SEGMENTATION`, `SEMANTIC_SEGMENTATION` and `POSE_ESTIMATION`.
|
||||
|
||||
|djl.input-class
|
||||
|Define input data type, a model may accept multiple input data type. Currently only the `ai.djl.modality.cv.Image` is supported.
|
||||
|
||||
|djl.output-class
|
||||
|Define output data type, a model may generate different outputs. Supported output classes are `ai.djl.modality.cv.output.DetectedObjects`, `ai.djl.modality.cv.output.CategoryMask`, `ai.djl.modality.Classifications`, `ai.djl.modality.cv.output.Joints` .
|
||||
|
||||
|djl.urls
|
||||
|Model repository URLs. Multiple may be supplied to search for models. Specifying a single URL can be used to load a specific model. Can be specified as comma delimited field or as an array in the configuration file.
|
||||
Current supported archive formats: `zip`, `tar`, `tar.gz`, `tgz`, `tar.z`.
|
||||
|
||||
Supported URL schemes: `file://` - load a model from local directory or archive file., `http(s)://` - load a model from an archive file from web server, `jar://` - load a model from an archive file in the class path, `djl://` - load a model from the model zoo, `s3://` - load a model from S3 bucket (requires djl aws extension), `hdfs://` - load a model from HDFS file system (requires djl hadoop extension)
|
||||
|
||||
|djl.model-filter
|
||||
| https://github.com/deepjavalibrary/djl/tree/master/model-zoo#how-to-find-a-pre-trained-model-in-the-model-zoo[Model Filters] used to lookup a model from model zoo .
|
||||
|
||||
|djl.group-id
|
||||
|Defines the `groupId` of the model to be loaded from the zoo.
|
||||
|
||||
|djl.model-artifact-id
|
||||
|Defines the `artifactId` of the model to be loaded from the zoo.
|
||||
|
||||
|djl.model-name
|
||||
|(Optional) Defines the modelName of the model to be loaded.
|
||||
Leave it empty if you want to load the latest version of the model.
|
||||
Use "saved_model" for TensorFlow saved models.
|
||||
|
||||
|djl.engine
|
||||
| Name of teh https://docs.djl.ai/docs/engine.html[Engine] to use https://docs.djl.ai/docs/engine.html#supported-engines[Supported engine names].
|
||||
|
||||
|djl.translator-factory
|
||||
| https://javadoc.io/doc/ai.djl/api/latest/ai/djl/translate/Translator.html[Translator] provides model pre-processing and postprocessing functionality. Multiple https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/cv/translator/package-summary.html[translators] are provided for different models, but you can implement your own translator if needed (see []). The translator-factory property allow to specify the translator to be used with the model.
|
||||
|
||||
|computer.vision.function.output-header-name
|
||||
|Name of the header that contains the JSON payload computed by the functions.
|
||||
|
||||
|computer.vision.function.augment-enabled
|
||||
|Enable image augmentation (false by default).
|
||||
|
||||
|===
|
||||
|
||||
Also, this function exposes its specific properties with a `computer.vision.function` prefix.
|
||||
See link:src/main/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionProperties.java[ComputerVisionFunctionProperties] for more details.
|
||||
|
||||
=== Example Configurations
|
||||
|
||||
All computer vision examples use the following Java code snippet to invoke the function:
|
||||
|
||||
[source,Java]
|
||||
----
|
||||
@SpringBootApplication
|
||||
public class TfObjectDetectionBootApp implements CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private Function<Message<byte[]>, Message<byte[]>> cvFunction;
|
||||
|
||||
@Override
|
||||
public void run(String... args) throws Exception {
|
||||
byte[] inputImage = new ClassPathResource("Image URI").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = cvFunction.apply(
|
||||
MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
// Augmented output image.
|
||||
byte[] outputImage = outputMessage.getPayload();
|
||||
|
||||
// JSON payload with the detected objects and their bounding boxes.
|
||||
String jsonBoundingBoxes = outputMessage.getHeader("cvjson", String.class);
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(TfObjectDetectionBootApp.class);
|
||||
}
|
||||
}
|
||||
----
|
||||
|
||||
==== Object Detection (TensorFlow)
|
||||
|
||||
You can leverage any of the existing TensorFlow models.
|
||||
Just comply the url of the model archive as a `djl.urls` property and set the `djl.translator-factory` to `org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory`.
|
||||
|
||||
----
|
||||
computer.vision.function.augment-enabled=true
|
||||
djl.application-type=OBJECT_DETECTION
|
||||
djl.input-class=ai.djl.modality.cv.Image
|
||||
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
|
||||
djl.engine=TensorFlow
|
||||
djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz
|
||||
djl.model-name=saved_model
|
||||
djl.translator-factory=org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory
|
||||
djl.arguments.threshold=0.3
|
||||
----
|
||||
|
||||
==== Object Detection (Yolo v8)
|
||||
|
||||
You can use the same Java snipped above, just change the configuration to use the Yolo v8 model:
|
||||
|
||||
----
|
||||
computer.vision.function.augment-enabled=true
|
||||
djl.application-type=OBJECT_DETECTION
|
||||
djl.input-class=ai.djl.modality.cv.Image
|
||||
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
|
||||
djl.engine=OnnxRuntime
|
||||
djl.urls=djl://ai.djl.onnxruntime/yolov8n
|
||||
djl.translator-factory=ai.djl.modality.cv.translator.YoloV8TranslatorFactory
|
||||
djl.arguments.threshold=0.3
|
||||
djl.arguments.width=640
|
||||
djl.arguments.height=640
|
||||
djl.arguments.resize=true
|
||||
djl.arguments.toTensor=true
|
||||
djl.arguments.applyRatio=true
|
||||
djl.arguments.maxBox=1000
|
||||
----
|
||||
|
||||
==== Instance Segmentation
|
||||
|
||||
Same Java code snipped but with the following configuration:
|
||||
|
||||
----
|
||||
computer.vision.function.augment-enabled=true
|
||||
djl.application-type=INSTANCE_SEGMENTATION
|
||||
djl.input-class=ai.djl.modality.cv.Image
|
||||
djl.output-class=ai.djl.modality.cv.output.DetectedObjects
|
||||
djl.arguments.threshold=0.3
|
||||
|
||||
djl.model-filter.backbone=resnet18
|
||||
djl.model-filter.flavor=v1b
|
||||
djl.model-filter.dataset=coco
|
||||
----
|
||||
|
||||
Note that here we didn't specify the model to be used, but used the model-filter to find a compatible model from the model zoo.
|
||||
|
||||
==== Semantic Segmentation
|
||||
|
||||
Same Java code snipped but with the following configuration:
|
||||
|
||||
----
|
||||
computer.vision.function.augment-enabled=true
|
||||
djl.application-type=SEMANTIC_SEGMENTATION
|
||||
djl.input-class=ai.djl.modality.cv.Image
|
||||
djl.output-class=ai.djl.modality.cv.output.CategoryMask
|
||||
djl.arguments.threshold=0.3
|
||||
|
||||
djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip
|
||||
djl.translator-factory=ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory
|
||||
djl.engine=PyTorch
|
||||
----
|
||||
|
||||
==== Image Classification
|
||||
|
||||
----
|
||||
djl.application-type=IMAGE_CLASSIFICATION
|
||||
djl.input-class=ai.djl.modality.cv.Image
|
||||
djl.output-class=ai.djl.modality.Classifications
|
||||
djl.arguments.threshold=0.3
|
||||
djl.engine=MXNet
|
||||
----
|
||||
|
||||
== Tests
|
||||
|
||||
See this link:src/test/java/org/springframework/cloud/fn/computer/vision/ComputerVisionFunctionConfigurationTests.java[test suite] for examples of how this function is used.
|
||||
|
||||
The link:src/test/java/org/springframework/cloud/fn/computer/vision/JsonHelperTests.java[JsonHelperTests] validates the JSON serialization and deserialization of the `ComputerVisionFunctionConfiguration` class values object classes.
|
||||
|
||||
7
function/spring-computer-vision-function/build.gradle
Normal file
7
function/spring-computer-vision-function/build.gradle
Normal file
@@ -0,0 +1,7 @@
|
||||
dependencies {
|
||||
api djlSpringAutoconfigure
|
||||
api "ai.djl.spring:djl-spring-boot-starter-tensorflow-auto:$djlSpringVersion"
|
||||
api "ai.djl.spring:djl-spring-boot-starter-pytorch-auto:$djlSpringVersion"
|
||||
api "ai.djl.spring:djl-spring-boot-starter-mxnet-auto:$djlSpringVersion"
|
||||
runtimeOnly 'ai.djl.onnxruntime:onnxruntime-engine'
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
* Copyright 2020-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
|
||||
import java.awt.image.RenderedImage;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.UncheckedIOException;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
|
||||
import ai.djl.inference.Predictor;
|
||||
import ai.djl.modality.Classifications;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.ImageFactory;
|
||||
import ai.djl.modality.cv.output.CategoryMask;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Joints;
|
||||
import ai.djl.spring.configuration.DjlAutoConfiguration;
|
||||
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.integration.support.MessageBuilder;
|
||||
import org.springframework.messaging.Message;
|
||||
|
||||
/**
|
||||
* A configuration class that provides the necessary beans for the Computer Vision
|
||||
* Function.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@AutoConfiguration(after = DjlAutoConfiguration.class)
|
||||
@EnableConfigurationProperties(ComputerVisionFunctionProperties.class)
|
||||
public class ComputerVisionFunctionConfiguration {
|
||||
|
||||
private static final ImageFactory IMAGE_FACTORY = ImageFactory.getInstance();
|
||||
|
||||
private final Supplier<Predictor<?, ?>> predictorProvider;
|
||||
|
||||
private final ComputerVisionFunctionProperties cvProperties;
|
||||
|
||||
public ComputerVisionFunctionConfiguration(Supplier<Predictor<?, ?>> predictorProvider,
|
||||
ComputerVisionFunctionProperties cvProperties) {
|
||||
|
||||
this.predictorProvider = predictorProvider;
|
||||
this.cvProperties = cvProperties;
|
||||
}
|
||||
|
||||
@Bean(name = "computerVisionFunction")
|
||||
@ConditionalOnProperty(prefix = "djl", name = "output-class",
|
||||
havingValue = "ai.djl.modality.cv.output.DetectedObjects")
|
||||
public Function<Message<byte[]>, Message<byte[]>> objectDetection() {
|
||||
|
||||
BiFunction<DetectedObjects, Image, byte[]> augmentImage = (detectedObjects, image) -> {
|
||||
Image newImage = image.duplicate();
|
||||
newImage.drawBoundingBoxes(detectedObjects);
|
||||
return getByteArray((RenderedImage) newImage.getWrappedImage(),
|
||||
this.cvProperties.getOutputImageFormatName());
|
||||
};
|
||||
|
||||
return predictor(JsonHelper::toJson, augmentImage);
|
||||
}
|
||||
|
||||
@Bean(name = "computerVisionFunction")
|
||||
@ConditionalOnProperty(prefix = "djl", name = "output-class",
|
||||
havingValue = "ai.djl.modality.cv.output.CategoryMask")
|
||||
public Function<Message<byte[]>, Message<byte[]>> semanticSegmentation() {
|
||||
BiFunction<CategoryMask, Image, byte[]> augmentImage = (mask, image) -> {
|
||||
Image newImage = image.duplicate();
|
||||
mask.drawMask(newImage, 200, 0);
|
||||
return getByteArray((RenderedImage) newImage.getWrappedImage(),
|
||||
this.cvProperties.getOutputImageFormatName());
|
||||
};
|
||||
|
||||
return predictor(JsonHelper::toJson, augmentImage);
|
||||
}
|
||||
|
||||
@Bean(name = "computerVisionFunction")
|
||||
@ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = "ai.djl.modality.Classifications")
|
||||
public Function<Message<byte[]>, Message<byte[]>> imageClassifications() {
|
||||
BiFunction<Classifications, Image, byte[]> augmentImage = (classifications, image) -> {
|
||||
Image newImage = image.duplicate();
|
||||
return getByteArray((RenderedImage) newImage.getWrappedImage(),
|
||||
this.cvProperties.getOutputImageFormatName());
|
||||
};
|
||||
|
||||
return predictor(JsonHelper::toJson, augmentImage);
|
||||
}
|
||||
|
||||
@Bean(name = "computerVisionFunction")
|
||||
@ConditionalOnProperty(prefix = "djl", name = "output-class", havingValue = "ai.djl.modality.cv.output.Joints")
|
||||
public Function<Message<byte[]>, Message<byte[]>> poseEstimation() {
|
||||
BiFunction<Joints, Image, byte[]> augmentImage = (joints, image) -> {
|
||||
Image newImage = image.duplicate();
|
||||
newImage.drawJoints(joints);
|
||||
return getByteArray((RenderedImage) newImage.getWrappedImage(),
|
||||
this.cvProperties.getOutputImageFormatName());
|
||||
};
|
||||
|
||||
return predictor(JsonHelper::toJson, augmentImage);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private <T> Function<Message<byte[]>, Message<byte[]>> predictor(Function<T, String> toJsonFunction,
|
||||
BiFunction<T, Image, byte[]> augmentImageFunction) {
|
||||
|
||||
return (input) -> {
|
||||
|
||||
try (Predictor<Image, T> predictor = (Predictor<Image, T>) this.predictorProvider.get()) {
|
||||
Image image = IMAGE_FACTORY.fromInputStream(new ByteArrayInputStream(input.getPayload()));
|
||||
|
||||
T output = predictor.predict(image);
|
||||
|
||||
String outputJson = toJsonFunction.apply(output);
|
||||
|
||||
byte[] outputImageBytes = input.getPayload();
|
||||
|
||||
if (this.cvProperties.isAugmentEnabled()) {
|
||||
outputImageBytes = augmentImageFunction.apply(output, image);
|
||||
}
|
||||
|
||||
String headerName = this.cvProperties.getOutputHeaderName();
|
||||
|
||||
return MessageBuilder.withPayload(outputImageBytes).setHeader(headerName, outputJson).build();
|
||||
}
|
||||
catch (Exception ex) {
|
||||
throw new IllegalStateException(ex);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private static byte[] getByteArray(RenderedImage image, String formatName) {
|
||||
try {
|
||||
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
|
||||
ImageIO.write(image, formatName, byteArrayOutputStream);
|
||||
return byteArrayOutputStream.toByteArray();
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new UncheckedIOException(ex);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright 2020-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
/**
|
||||
* Configuration properties for the Computer Vision Function.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@ConfigurationProperties("computer.vision.function")
|
||||
public class ComputerVisionFunctionProperties {
|
||||
|
||||
/**
|
||||
* Enable image augmentation.
|
||||
*/
|
||||
private boolean augmentEnabled = false;
|
||||
|
||||
/**
|
||||
* Output augmented image format name.
|
||||
*/
|
||||
private String outputImageFormatName = "png";
|
||||
|
||||
/**
|
||||
* Name of the header that contains the JSON payload computed by the functions.
|
||||
*/
|
||||
private String outputHeaderName = "cvjson";
|
||||
|
||||
public boolean isAugmentEnabled() {
|
||||
return this.augmentEnabled;
|
||||
}
|
||||
|
||||
public void setAugmentEnabled(boolean augmentImage) {
|
||||
this.augmentEnabled = augmentImage;
|
||||
}
|
||||
|
||||
public String getOutputImageFormatName() {
|
||||
return this.outputImageFormatName;
|
||||
}
|
||||
|
||||
public void setOutputImageFormatName(String outputImageFormatName) {
|
||||
this.outputImageFormatName = outputImageFormatName;
|
||||
}
|
||||
|
||||
public String getOutputHeaderName() {
|
||||
return this.outputHeaderName;
|
||||
}
|
||||
|
||||
public void setOutputHeaderName(String jsonHeaderName) {
|
||||
this.outputHeaderName = jsonHeaderName;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.List;
|
||||
|
||||
import ai.djl.modality.Classifications;
|
||||
import ai.djl.modality.cv.output.BoundingBox;
|
||||
import ai.djl.modality.cv.output.CategoryMask;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Joints;
|
||||
import ai.djl.modality.cv.output.Rectangle;
|
||||
import ai.djl.util.JsonUtils;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonDeserializationContext;
|
||||
import com.google.gson.JsonDeserializer;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonParseException;
|
||||
import com.google.gson.JsonSerializationContext;
|
||||
import com.google.gson.JsonSerializer;
|
||||
|
||||
/**
|
||||
* Helper class to serialize and deserialize {@link DetectedObjects},
|
||||
* {@link Classifications}, {@link CategoryMask} and {@link Joints} to/from JSON.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public final class JsonHelper {
|
||||
|
||||
private static final Gson GSON = JsonUtils.builder().create();
|
||||
|
||||
private JsonHelper() {
|
||||
}
|
||||
|
||||
public static String toJson(Joints joints) {
|
||||
return GSON.toJson(joints);
|
||||
}
|
||||
|
||||
public static Joints toJoints(String json) {
|
||||
return GSON.fromJson(json, Joints.class);
|
||||
}
|
||||
|
||||
public static String toJson(CategoryMask categoryMask) {
|
||||
return GSON.toJson(Mask.fromCategoryMask(categoryMask));
|
||||
}
|
||||
|
||||
public static CategoryMask toCategoryMask(String json) {
|
||||
return GSON.fromJson(json, Mask.class).toCategoryMask();
|
||||
}
|
||||
|
||||
public static String toJson(Classifications classifications) {
|
||||
return GSON.toJson(classifications);
|
||||
}
|
||||
|
||||
public static Classifications toClassifications(String json) {
|
||||
return GSON.fromJson(json, Classifications.class);
|
||||
}
|
||||
|
||||
private static final Gson GSON2 = JsonUtils.builder()
|
||||
.registerTypeAdapter(BoundingBox.class, new BoundingBoxAdapter())
|
||||
.create();
|
||||
|
||||
public static String toJson(DetectedObjects detectedObject) {
|
||||
return GSON2.toJson(detectedObject);
|
||||
}
|
||||
|
||||
public static DetectedObjects toDetectedObjects(String json) {
|
||||
return GSON2.fromJson(json, DetectedObjects.class);
|
||||
}
|
||||
|
||||
public record Mask(List<String> classes, int[][] mask) {
|
||||
|
||||
public static Mask fromCategoryMask(CategoryMask categoryMask) {
|
||||
return new Mask(categoryMask.getClasses(), categoryMask.getMask());
|
||||
}
|
||||
|
||||
public CategoryMask toCategoryMask() {
|
||||
return new CategoryMask(this.classes, this.mask);
|
||||
}
|
||||
}
|
||||
|
||||
public static class BoundingBoxAdapter implements JsonSerializer<BoundingBox>, JsonDeserializer<BoundingBox> {
|
||||
|
||||
@Override
|
||||
public JsonElement serialize(BoundingBox boundingBox, Type typeOfSrc, JsonSerializationContext context) {
|
||||
return context.serialize(boundingBox);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BoundingBox deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context)
|
||||
throws JsonParseException {
|
||||
return context.deserialize(json, Rectangle.class);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* Provides classes for the Computer Vision Function.
|
||||
*/
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
@@ -0,0 +1,184 @@
|
||||
/*
|
||||
* Copyright 2020-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision.translator;
|
||||
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Scanner;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.output.BoundingBox;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Rectangle;
|
||||
import ai.djl.modality.cv.util.NDImageUtils;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDList;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.translate.NoBatchifyTranslator;
|
||||
import ai.djl.translate.TranslatorContext;
|
||||
import ai.djl.util.JsonUtils;
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
|
||||
/**
|
||||
* A {@link NoBatchifyTranslator} that post-processes the output of a TensorFlow
|
||||
* SavedModel Object Detection model.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public final class TensorflowSavedModelObjectDetectionTranslator
|
||||
implements NoBatchifyTranslator<Image, DetectedObjects> {
|
||||
|
||||
private static final String ITEM_DELIMITER = "item ";
|
||||
|
||||
private static final String DEFAULT_MSCOCO_LABELS_URL = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
|
||||
|
||||
private static final String DETECTION_BOXES = "detection_boxes";
|
||||
|
||||
private static final String DETECTION_SCORES = "detection_scores";
|
||||
|
||||
private static final String DETECTION_CLASSES = "detection_classes";
|
||||
|
||||
private String classLabelsUrl;
|
||||
|
||||
private Map<Integer, String> classLabels;
|
||||
|
||||
private int maxBoxes;
|
||||
|
||||
private float threshold;
|
||||
|
||||
public TensorflowSavedModelObjectDetectionTranslator() {
|
||||
this(DEFAULT_MSCOCO_LABELS_URL, 10, 0.3f);
|
||||
}
|
||||
|
||||
public TensorflowSavedModelObjectDetectionTranslator(String categoryLabelsUrl, int maxBoxes, float threshold) {
|
||||
this.classLabelsUrl = categoryLabelsUrl;
|
||||
this.maxBoxes = maxBoxes;
|
||||
this.threshold = threshold;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public NDList processInput(TranslatorContext ctx, Image input) {
|
||||
// input to tf object-detection models is a list of tensors, hence NDList
|
||||
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
|
||||
// optionally resize the image for faster processing
|
||||
array = NDImageUtils.resize(array, 224);
|
||||
// tf object-detection models expect 8 bit unsigned integer tensor
|
||||
array = array.toType(DataType.UINT8, true);
|
||||
// tf object-detection models expect a 4 dimensional input
|
||||
array = array.expandDims(0);
|
||||
|
||||
return new NDList(array);
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public void prepare(TranslatorContext ctx) throws IOException {
|
||||
if (this.classLabels == null) {
|
||||
this.classLabels = loadSynset();
|
||||
}
|
||||
}
|
||||
|
||||
private Map<Integer, String> loadSynset() throws IOException {
|
||||
Map<Integer, String> map = new ConcurrentHashMap<>();
|
||||
int maxId = 0;
|
||||
try (InputStream is = new BufferedInputStream(new URL(this.classLabelsUrl).openStream());
|
||||
Scanner scanner = new Scanner(is, StandardCharsets.UTF_8)) {
|
||||
|
||||
scanner.useDelimiter(ITEM_DELIMITER);
|
||||
while (scanner.hasNext()) {
|
||||
String content = scanner.next();
|
||||
content = content.replaceAll("(\"|\\d)\\n\\s", "$1,");
|
||||
Item item = JsonUtils.GSON.fromJson(content, Item.class);
|
||||
map.put(item.id, item.displayName);
|
||||
if (item.id > maxId) {
|
||||
maxId = item.id;
|
||||
}
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
/** {@inheritDoc} */
|
||||
@Override
|
||||
public DetectedObjects processOutput(TranslatorContext ctx, NDList list) {
|
||||
// output of tf object-detection models is a list of tensors, hence NDList in djl
|
||||
// output NDArray order in the list are not guaranteed
|
||||
|
||||
int[] classIds = null;
|
||||
float[] probabilities = null;
|
||||
NDArray boundingBoxes = null;
|
||||
for (NDArray array : list) {
|
||||
if (DETECTION_BOXES.equals(array.getName())) {
|
||||
boundingBoxes = array.get(0);
|
||||
}
|
||||
else if (DETECTION_SCORES.equals(array.getName())) {
|
||||
probabilities = array.get(0).toFloatArray();
|
||||
}
|
||||
else if (DETECTION_CLASSES.equals(array.getName())) {
|
||||
// class id is between 1 - number of classes
|
||||
classIds = array.get(0).toType(DataType.INT32, true).toIntArray();
|
||||
}
|
||||
}
|
||||
Objects.requireNonNull(classIds);
|
||||
Objects.requireNonNull(probabilities);
|
||||
Objects.requireNonNull(boundingBoxes);
|
||||
|
||||
List<String> retNames = new ArrayList<>();
|
||||
List<Double> retProbs = new ArrayList<>();
|
||||
List<BoundingBox> retBB = new ArrayList<>();
|
||||
|
||||
// result are already sorted
|
||||
for (int i = 0; i < Math.min(classIds.length, this.maxBoxes); ++i) {
|
||||
int classId = classIds[i];
|
||||
double probability = probabilities[i];
|
||||
// classId starts from 1, -1 means background
|
||||
if (classId > 0 && probability > this.threshold) {
|
||||
String className = this.classLabels.getOrDefault(classId, "#" + classId);
|
||||
float[] box = boundingBoxes.get(i).toFloatArray();
|
||||
float yMin = box[0];
|
||||
float xMin = box[1];
|
||||
float yMax = box[2];
|
||||
float xMax = box[3];
|
||||
Rectangle rect = new Rectangle(xMin, yMin, xMax - xMin, yMax - yMin);
|
||||
retNames.add(className);
|
||||
retProbs.add(probability);
|
||||
retBB.add(rect);
|
||||
}
|
||||
}
|
||||
|
||||
return new DetectedObjects(retNames, retProbs, retBB);
|
||||
}
|
||||
|
||||
private static final class Item {
|
||||
|
||||
int id;
|
||||
|
||||
@SerializedName("display_name")
|
||||
String displayName;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision.translator;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import ai.djl.Model;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.translator.ObjectDetectionTranslatorFactory;
|
||||
import ai.djl.translate.Translator;
|
||||
|
||||
/**
|
||||
* Translator for TensorFlow Object Detection SavedModel.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class TensorflowSavedModelObjectDetectionTranslatorFactory extends ObjectDetectionTranslatorFactory {
|
||||
|
||||
@Override
|
||||
protected Translator<Image, DetectedObjects> buildBaseTranslator(Model model, Map<String, ?> arguments) {
|
||||
return new TensorflowSavedModelObjectDetectionTranslator();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
/**
|
||||
* Provides classes for translating the output of the computer vision function.
|
||||
*/
|
||||
package org.springframework.cloud.fn.computer.vision.translator;
|
||||
@@ -0,0 +1 @@
|
||||
org.springframework.cloud.fn.computer.vision.ComputerVisionFunctionConfiguration
|
||||
@@ -0,0 +1,331 @@
|
||||
/*
|
||||
* Copyright 2020-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import ai.djl.modality.Classifications;
|
||||
import ai.djl.modality.cv.Image;
|
||||
import ai.djl.modality.cv.output.CategoryMask;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Joints;
|
||||
import ai.djl.modality.cv.translator.SemanticSegmentationTranslatorFactory;
|
||||
import ai.djl.modality.cv.translator.YoloV8TranslatorFactory;
|
||||
import ai.djl.repository.zoo.ZooModel;
|
||||
import ai.djl.spring.configuration.ApplicationType;
|
||||
import ai.djl.spring.configuration.DjlAutoConfiguration;
|
||||
import ai.djl.spring.configuration.DjlConfigurationProperties;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.DisabledOnOs;
|
||||
import org.junit.jupiter.api.condition.OS;
|
||||
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.cloud.fn.computer.vision.translator.TensorflowSavedModelObjectDetectionTranslatorFactory;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.messaging.Message;
|
||||
import org.springframework.messaging.support.MessageBuilder;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class ComputerVisionFunctionConfigurationTests {
|
||||
|
||||
private ApplicationContextRunner applicationContextRunner;
|
||||
|
||||
@BeforeEach
|
||||
public void setUp() {
|
||||
applicationContextRunner = new ApplicationContextRunner().withConfiguration(
|
||||
AutoConfigurations.of(DjlAutoConfiguration.class, ComputerVisionFunctionConfiguration.class));
|
||||
}
|
||||
|
||||
/**
|
||||
* This configuration can be used to load any of the Tensorflow2 models for object
|
||||
* detection from here:
|
||||
* https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md
|
||||
*/
|
||||
@Test
|
||||
public void tf2SavedModel() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=true",
|
||||
"djl.application-type=" + ApplicationType.OBJECT_DETECTION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + DetectedObjects.class.getName(),
|
||||
"djl.engine=TensorFlow",
|
||||
"djl.urls=http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz",
|
||||
"djl.model-name=saved_model",
|
||||
"djl.translator-factory=" + TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName(),
|
||||
"djl.arguments.threshold=0.3")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
|
||||
assertThat(djlProperties.getEngine()).isEqualTo("TensorFlow");
|
||||
assertThat(djlProperties.getUrls()).contains(
|
||||
"http://download.tensorflow.org/models/object_detection/tf2/20200711/faster_rcnn_inception_resnet_v2_1024x1024_coco17_tpu-8.tar.gz");
|
||||
assertThat(djlProperties.getModelName()).isEqualTo("saved_model");
|
||||
assertThat(djlProperties.getTranslatorFactory())
|
||||
.isEqualTo(TensorflowSavedModelObjectDetectionTranslatorFactory.class.getName());
|
||||
|
||||
byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
|
||||
var json = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
assertThat(JsonHelper.toDetectedObjects(json)).isNotNull();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void yolov8Detection() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=true",
|
||||
"djl.application-type=" + ApplicationType.OBJECT_DETECTION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + DetectedObjects.class.getName(),
|
||||
"djl.engine=OnnxRuntime",
|
||||
"djl.urls=djl://ai.djl.onnxruntime/yolov8n",
|
||||
"djl.translator-factory=" + YoloV8TranslatorFactory.class.getName(),
|
||||
"djl.arguments.threshold=0.3",
|
||||
"djl.arguments.width=640",
|
||||
"djl.arguments.height=640",
|
||||
"djl.arguments.resize=true",
|
||||
"djl.arguments.toTensor=true",
|
||||
"djl.arguments.applyRatio=true",
|
||||
"djl.arguments.maxBox=1000")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.OBJECT_DETECTION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
|
||||
assertThat(djlProperties.getEngine()).isEqualTo("OnnxRuntime");
|
||||
assertThat(djlProperties.getUrls()).contains("djl://ai.djl.onnxruntime/yolov8n");
|
||||
assertThat(djlProperties.getTranslatorFactory()).isEqualTo(YoloV8TranslatorFactory.class.getName());
|
||||
|
||||
byte[] inputImage = new ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
|
||||
var json = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
var detectionObjects = JsonHelper.toDetectedObjects(json);
|
||||
|
||||
assertThat(detectionObjects).isNotNull();
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void instanceSegmentation() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=true",
|
||||
"djl.application-type=" + ApplicationType.INSTANCE_SEGMENTATION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + DetectedObjects.class.getName(),
|
||||
"djl.arguments.threshold=0.3",
|
||||
|
||||
"djl.model-filter.backbone=resnet18",
|
||||
"djl.model-filter.flavor=v1b",
|
||||
"djl.model-filter.dataset=coco")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.INSTANCE_SEGMENTATION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(DetectedObjects.class);
|
||||
|
||||
// byte[] inputImage = new
|
||||
// ClassPathResource("/object-detection.jpg").getInputStream().readAllBytes();
|
||||
byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
String json = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
assertThat(JsonHelper.toDetectedObjects(json)).isNotNull();
|
||||
});
|
||||
}
|
||||
|
||||
@DisabledOnOs(OS.WINDOWS)
|
||||
@Test
|
||||
public void semanticSegmentation() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=true",
|
||||
"djl.application-type=" + ApplicationType.SEMANTIC_SEGMENTATION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + CategoryMask.class.getName(),
|
||||
"djl.arguments.threshold=0.3",
|
||||
|
||||
"djl.urls=https://mlrepo.djl.ai/model/cv/semantic_segmentation/ai/djl/pytorch/deeplabv3/0.0.1/deeplabv3.zip",
|
||||
"djl.translator-factory=" + SemanticSegmentationTranslatorFactory.class.getName(),
|
||||
"djl.engine=PyTorch")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.SEMANTIC_SEGMENTATION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(CategoryMask.class);
|
||||
|
||||
byte[] inputImage = new ClassPathResource("/amsterdam-cityscape.jpg").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
|
||||
String ssJson = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
assertThat(JsonHelper.toCategoryMask(ssJson)).isNotNull();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void imageClassifications() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=false",
|
||||
"djl.application-type=" + ApplicationType.IMAGE_CLASSIFICATION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + Classifications.class.getName(),
|
||||
"djl.arguments.threshold=0.3",
|
||||
"djl.engine=MXNet")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.IMAGE_CLASSIFICATION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(Classifications.class);
|
||||
|
||||
byte[] inputImage = new ClassPathResource("/karakatschan.jpg").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
|
||||
String json = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
assertThat(JsonHelper.toClassifications(json)).isNotNull();
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void poseEstimation() {
|
||||
applicationContextRunner.withPropertyValues(
|
||||
// @formatter:off
|
||||
"computer.vision.function.augment-enabled=true",
|
||||
"djl.application-type=" + ApplicationType.POSE_ESTIMATION,
|
||||
"djl.input-class=" + Image.class.getName(),
|
||||
"djl.output-class=" + Joints.class.getName(),
|
||||
"djl.arguments.threshold=0.3",
|
||||
"djl.model-filter.backbone=resnet18",
|
||||
"djl.model-filter.flavor=v1b",
|
||||
"djl.model-filter.dataset=imagenet")
|
||||
// @formatter:on
|
||||
.run((context) -> {
|
||||
assertThat(context).hasSingleBean(ZooModel.class);
|
||||
assertThat(context).hasBean("predictorProvider");
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Function<Message<byte[]>, Message<byte[]>> predictor = (Function<Message<byte[]>, Message<byte[]>>) context
|
||||
.getBean("computerVisionFunction");
|
||||
|
||||
var djlProperties = context.getBean(DjlConfigurationProperties.class);
|
||||
|
||||
assertThat(djlProperties.getApplicationType()).isEqualTo(ApplicationType.POSE_ESTIMATION);
|
||||
assertThat(djlProperties.getInputClass()).isEqualTo(Image.class);
|
||||
assertThat(djlProperties.getOutputClass()).isEqualTo(Joints.class);
|
||||
|
||||
byte[] inputImage = new ClassPathResource("/pose.png").getInputStream().readAllBytes();
|
||||
|
||||
Message<byte[]> outputMessage = predictor.apply(MessageBuilder.withPayload(inputImage).build());
|
||||
|
||||
assertThat(outputMessage).isNotNull();
|
||||
assertThat(outputMessage.getPayload()).isNotNull();
|
||||
assertThat(outputMessage.getPayload().length).isGreaterThan(0);
|
||||
assertThat(outputMessage.getHeaders()).containsKey("cvjson");
|
||||
|
||||
String ssJson = outputMessage.getHeaders().get("cvjson", String.class);
|
||||
|
||||
assertThat(JsonHelper.toJoints(ssJson)).isNotNull();
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.computer.vision;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import ai.djl.modality.Classifications;
|
||||
import ai.djl.modality.cv.output.CategoryMask;
|
||||
import ai.djl.modality.cv.output.DetectedObjects;
|
||||
import ai.djl.modality.cv.output.Rectangle;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class JsonHelperTests {
|
||||
|
||||
@Test
|
||||
public void categoryMask() {
|
||||
|
||||
var categoryMask = new CategoryMask(List.of("a", "b", "c"), new int[][] { { 1, 2, 3 }, { 4, 5, 6 } });
|
||||
|
||||
var json = JsonHelper.toJson(categoryMask);
|
||||
|
||||
assertThat(json).isNotEmpty();
|
||||
|
||||
var categoryMask2 = JsonHelper.toCategoryMask(json);
|
||||
|
||||
assertThat(categoryMask.getClasses()).isEqualTo(categoryMask2.getClasses());
|
||||
assertThat(categoryMask.getMask()).isEqualTo(categoryMask2.getMask());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classifications() {
|
||||
|
||||
var classifications = new Classifications(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3));
|
||||
classifications.setTopK(3);
|
||||
|
||||
var json = JsonHelper.toJson(classifications);
|
||||
|
||||
assertThat(json).isNotEmpty();
|
||||
|
||||
var classifications2 = JsonHelper.toClassifications(json);
|
||||
|
||||
assertThat(classifications2.getClassNames()).isEqualTo(classifications.getClassNames());
|
||||
assertThat(classifications2.getProbabilities()).isEqualTo(classifications.getProbabilities());
|
||||
assertThat(classifications2.topK()).hasSize(3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void detectedObjects() {
|
||||
DetectedObjects detectedObjects = new DetectedObjects(List.of("a", "b", "c"), List.of(0.1, 0.2, 0.3),
|
||||
List.of(new Rectangle(1, 2, 3, 4), new Rectangle(5, 6, 7, 8), new Rectangle(9, 10, 11, 12)));
|
||||
detectedObjects.setTopK(3);
|
||||
|
||||
var json = JsonHelper.toJson(detectedObjects);
|
||||
|
||||
assertThat(json).isNotEmpty();
|
||||
|
||||
var detectedObjects2 = JsonHelper.toDetectedObjects(json);
|
||||
|
||||
assertThat(detectedObjects2.getClassNames()).isEqualTo(detectedObjects.getClassNames());
|
||||
assertThat(detectedObjects2.getProbabilities()).isEqualTo(detectedObjects.getProbabilities());
|
||||
assertThat(detectedObjects2.topK()).hasSize(3);
|
||||
|
||||
assertThat(detectedObjects2.getNumberOfObjects()).isEqualTo(3);
|
||||
}
|
||||
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 102 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 1013 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.2 MiB |
Reference in New Issue
Block a user