Tensorflow functions and applications

* initial step
 * Tensorflow models functional model redesign

      -- Based on https://tzolov.github.io/mind-model-services
      -- Resolves #5

 * Add object detection processor README
 * Add image recognition processor README
 * Initial Tensorflow commonn README
 * Initial Tensorflow commonn README
 * Tensorflow common diagram
 * Tensorflow docs code
 * Tensorflow docs code snippets improve
 * Tensorflow docs code snippets improve
 * Tensorflow docs code snippets improve
 * Tensorflow docs code snippets improve
 * Add semantic segmentation function. add object detecteion function readme
 * oo images
 * Furether oo readme improvments
 * Final obj detection readme fixes
 * Add image recognition readme
 * Add image recognition readme 2
 * Semantic segmentation readme
 * Segmentation readme
 * Semantic segmentation readme 3
 * Fix image recognition and object detcion app starter dependecies

 * Add metadata for Tensorflow apps
This commit is contained in:
Christian Tzolov
2020-06-11 17:07:44 +02:00
committed by Soby Chacko
parent 3bb9e066b9
commit dffb467da4
66 changed files with 10300 additions and 1 deletions

View File

@@ -0,0 +1,119 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.util.CollectionUtils;
/**
* Augment the input image fromMemory detected object bounding boxes and categories.
* For mask models and withMask set to true it draws the instance segmentation image as well.
*
* @author Christian Tzolov
*/
public class ObjectDetectionImageAugmenter implements BiFunction<byte[], List<ObjectDetection>, byte[]> {
private static final Log logger = LogFactory.getLog(ObjectDetectionImageAugmenter.class);
/** Make checkstyle happy. **/
public static final String DEFAULT_IMAGE_FORMAT = "jpg";
private String imageFormat = DEFAULT_IMAGE_FORMAT;
private final boolean withMask;
private boolean agnosticColors = false;
public ObjectDetectionImageAugmenter() {
this(false);
}
public ObjectDetectionImageAugmenter(boolean withMask) {
this.withMask = withMask;
}
public boolean isAgnosticColors() {
return agnosticColors;
}
public void setAgnosticColors(boolean agnosticColors) {
this.agnosticColors = agnosticColors;
}
public String getImageFormat() {
return imageFormat;
}
public void setImageFormat(String imageFormat) {
this.imageFormat = imageFormat;
}
@Override
public byte[] apply(byte[] imageBytes, List<ObjectDetection> objectDetections) {
if (!CollectionUtils.isEmpty(objectDetections)) {
try {
BufferedImage bufferedImage = ImageIO.read(new ByteArrayInputStream(imageBytes));
for (ObjectDetection od : objectDetections) {
int y1 = (int) (od.getY1() * (float) bufferedImage.getHeight());
int x1 = (int) (od.getX1() * (float) bufferedImage.getWidth());
int y2 = (int) (od.getY2() * (float) bufferedImage.getHeight());
int x2 = (int) (od.getX2() * (float) bufferedImage.getWidth());
int cid = od.getCid();
String labelName = od.getName();
int probability = (int) (100 * od.getConfidence());
String title = labelName + ": " + probability + "%";
GraphicsUtils.drawBoundingBox(bufferedImage, cid, title, x1, y1, x2, y2, this.agnosticColors);
if (this.withMask && od.getMask() != null) {
float[][] mask = od.getMask();
if (mask != null) {
Color maskColor = this.agnosticColors ? null : GraphicsUtils.getClassColor(cid);
BufferedImage maskImage = GraphicsUtils.createMaskImage(
mask, x2 - x1, y2 - y1, maskColor);
GraphicsUtils.overlayImages(bufferedImage, maskImage, x1, y1);
}
}
}
imageBytes = GraphicsUtils.toImageByteArray(bufferedImage, this.getImageFormat());
}
catch (IOException e) {
logger.error(e);
}
}
// Null mend that QR image is found and not output message will be send.
return imageBytes;
}
}

View File

@@ -0,0 +1,76 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
/**
* Converts byte array image into a input Tensor for the Object Detection API.
*
* @author Christian Tzolov
*/
public class ObjectDetectionInputAdapter implements Function<byte[], Map<String, Tensor<?>>>, AutoCloseable {
private static final Log logger = LogFactory.getLog(ObjectDetectionInputAdapter.class);
/** Make checkstyle happy. **/
public static final String RAW_IMAGE = "raw_image";
/** Make checkstyle happy. **/
public static final String NORMALIZED_IMAGE = "normalized_image";
/** Make checkstyle happy. **/
public static final long CHANNELS = 3;
private final GraphRunner imageLoaderGraph;
public ObjectDetectionInputAdapter() {
this.imageLoaderGraph = new GraphRunner(RAW_IMAGE, NORMALIZED_IMAGE)
.withGraphDefinition(tf -> {
Placeholder<String> rawImage = tf.withName(RAW_IMAGE).placeholder(String.class);
Operand<UInt8> decodedImage = tf.dtypes.cast(
tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(CHANNELS)), UInt8.class);
// Expand dimensions since the model expects images to have shape: [1, H, W, 3]
tf.withName(NORMALIZED_IMAGE).expandDims(decodedImage, tf.constant(0));
});
}
@Override
public Map<String, Tensor<?>> apply(byte[] inputImage) {
try (Tensor inputTensor = Tensor.create(inputImage)) {
return this.imageLoaderGraph.apply(Collections.singletonMap(RAW_IMAGE, inputTensor));
}
}
@Override
public void close() {
if (this.imageLoaderGraph != null) {
this.imageLoaderGraph.close();
}
}
}

View File

@@ -0,0 +1,98 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
/**
* Converts byte array image into a input Tensor for the Object Detection API. The computed image tensors uses the
* 'image_tensor' model placeholder.
*
* @author Christian Tzolov
*/
public class ObjectDetectionInputConverter implements Function<byte[][], Map<String, Tensor<?>>> {
private static final Log logger = LogFactory.getLog(ObjectDetectionInputConverter.class);
private static final long CHANNELS = 3;
/** Make checkstyle happy. **/
public static final String IMAGE_TENSOR_FEED_NAME = "image_tensor";
@Override
public Map<String, Tensor<?>> apply(byte[][] input) {
return Collections.singletonMap(IMAGE_TENSOR_FEED_NAME, makeImageTensor(input));
}
private static Tensor<UInt8> makeImageTensor(byte[][] imageBytesArray) {
try {
int batchSize = imageBytesArray.length;
ByteBuffer byteBuffer = null;
long[] shape = null;
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
byte[] imageBytes = imageBytesArray[batchIndex];
ByteArrayInputStream is = new ByteArrayInputStream(imageBytes);
BufferedImage img = ImageIO.read(is);
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
img = GraphicsUtils.toBufferedImageType(img, BufferedImage.TYPE_3BYTE_BGR);
}
if (byteBuffer == null) {
byteBuffer = ByteBuffer.allocate((int) (batchSize * img.getHeight() * img.getWidth() * CHANNELS));
shape = new long[] { batchSize, img.getHeight(), img.getWidth(), CHANNELS };
}
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
// ImageIO.read produces BGR-encoded images, while the model expects RGB.
bgrToRgb(data);
byteBuffer.put(data);
}
byteBuffer.flip();
return Tensor.create(UInt8.class, shape, byteBuffer);
}
catch (IOException e) {
throw new IllegalArgumentException("Incorrect image format", e);
}
}
private static void bgrToRgb(byte[] data) {
for (int i = 0; i < data.length; i += 3) {
byte tmp = data[i];
data[i] = data[i + 2];
data[i + 2] = tmp;
}
}
}

View File

@@ -0,0 +1,187 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import com.google.protobuf.TextFormat;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tensorflow.Tensor;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.cloud.fn.object.detection.protos.StringIntLabelMapOuterClass;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
/**
* Converts the Tensorflow Object Detection result into {@link ObjectDetection} list.
* The pre-trained Object Detection models (http://bit.ly/2osxMAY) produce 3 tensor outputs:
* (1) detection_classes - containing the ids of detected objects, (2) detection_scores - confidence probabilities of the
* detected object and (3) detection_boxes - the object bounding boxes withing the images.
*
* The MASK based models provide to 2 additional tensors: (4) num_detections and (5) detection_masks.
*
* All outputs tensors are float arrays, having:
* - 1 as the first dimension
* - maxObjects as the second dimension
* While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
* This can be verified by looking at scoresT.shape() etc.
*
* The format detected classes (e.g. labels) names is defined by the 'string_int_labels_map.proto'. The input list
* is available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
*
* @author Christian Tzolov
*/
public class ObjectDetectionOutputConverter implements Function<Map<String, Tensor<?>>, List<List<ObjectDetection>>> {
private static final Log logger = LogFactory.getLog(ObjectDetectionOutputConverter.class);
/** DETECTION_CLASSES. */
public static final String DETECTION_CLASSES = "detection_classes";
/** DETECTION_SCORES. */
public static final String DETECTION_SCORES = "detection_scores";
/** DETECTION_BOXES. */
public static final String DETECTION_BOXES = "detection_boxes";
/** DETECTION_MASKS. */
public static final String DETECTION_MASKS = "detection_masks";
/** NUM_DETECTIONS. */
public static final String NUM_DETECTIONS = "num_detections";
private final String[] labels;
private float confidence;
private List<String> modelFetch;
public ObjectDetectionOutputConverter(Resource labelsResource, float confidence, List<String> modelFetch) {
this.confidence = confidence;
this.modelFetch = modelFetch;
try {
this.labels = loadLabels(labelsResource);
Assert.notNull(this.labels, String.format("Failed to initialize object labels [%s].", labelsResource));
}
catch (Exception e) {
throw new RuntimeException(String.format("Failed to initialize object labels [%s].", labelsResource), e);
}
logger.info(String.format("Object labels [%s] loaded.", labelsResource));
}
/**
* Loads object labels in the string_int_label_map.proto.
* @param labelsResource location of the labels as a resource
* @return
*/
private static String[] loadLabels(Resource labelsResource) throws Exception {
try (InputStream is = labelsResource.getInputStream()) {
String text = StreamUtils.copyToString(is, Charset.forName("UTF-8"));
StringIntLabelMapOuterClass.StringIntLabelMap.Builder builder =
StringIntLabelMapOuterClass.StringIntLabelMap.newBuilder();
TextFormat.merge(text, builder);
StringIntLabelMapOuterClass.StringIntLabelMap proto = builder.build();
int maxLabelId = proto.getItemList().stream()
.map(StringIntLabelMapOuterClass.StringIntLabelMapItem::getId)
.max(Comparator.comparing(i -> i))
.orElse(-1);
String[] labelIdToNameMap = new String[maxLabelId + 1];
for (StringIntLabelMapOuterClass.StringIntLabelMapItem item : proto.getItemList()) {
if (!StringUtils.isEmpty(item.getDisplayName())) {
labelIdToNameMap[item.getId()] = item.getDisplayName();
}
else {
// Common practice is to set the name to a MID or Synsets Id. Synset is a set of synonyms that
// share a common meaning: https://en.wikipedia.org/wiki/WordNet
labelIdToNameMap[item.getId()] = item.getName();
}
}
return labelIdToNameMap;
}
}
@Override
public List<List<ObjectDetection>> apply(Map<String, Tensor<?>> tensorMap) {
try (Tensor<Float> scoresTensor = tensorMap.get(DETECTION_SCORES).expect(Float.class);
Tensor<Float> classesTensor = tensorMap.get(DETECTION_CLASSES).expect(Float.class);
Tensor<Float> boxesTensor = tensorMap.get(DETECTION_BOXES).expect(Float.class)
) {
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int batchSize = (int) scoresTensor.shape()[0];
int maxObjects = (int) scoresTensor.shape()[1];
float[][] scores = scoresTensor.copyTo(new float[batchSize][maxObjects]);
float[][] classes = classesTensor.copyTo(new float[batchSize][maxObjects]);
float[][][] boxes = boxesTensor.copyTo(new float[batchSize][maxObjects][4]);
List<List<ObjectDetection>> batchObjectDetections = new ArrayList<>();
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
List<ObjectDetection> objectDetections = new ArrayList<>();
// Collect only the objects whose scores are at above the configured confidence threshold.
for (int i = 0; i < scores[batchIndex].length; ++i) {
if (scores[batchIndex][i] >= confidence) {
String category = labels[(int) classes[batchIndex][i]];
float score = scores[batchIndex][i];
ObjectDetection od = new ObjectDetection();
od.setName(category);
od.setConfidence(score);
od.setX1(boxes[batchIndex][i][1]);
od.setY1(boxes[batchIndex][i][0]);
od.setX2(boxes[batchIndex][i][3]);
od.setY2(boxes[batchIndex][i][2]);
od.setCid((int) classes[batchIndex][i]);
// Mask allows image-segmentation
if (modelFetch.contains(DETECTION_MASKS) && modelFetch.contains(NUM_DETECTIONS)) {
Tensor<Float> masksTensor = tensorMap.get(DETECTION_MASKS).expect(Float.class);
Tensor<Float> numDetections = tensorMap.get(NUM_DETECTIONS).expect(Float.class);
float nd = numDetections.copyTo(new float[batchSize])[0];
if (masksTensor != null) {
long[] shape = masksTensor.shape();
float[][][][] masks = masksTensor.copyTo(new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]);
od.setMask(masks[batchIndex][i]);
logger.debug(String.format("Num detections: %s, Masks: %s", nd, masks));
}
}
objectDetections.add(od);
}
}
batchObjectDetections.add(objectDetections);
}
return batchObjectDetections;
}
}
}

View File

@@ -0,0 +1,147 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.TensorFlowService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.util.StreamUtils;
/**
* Convenience class that leverages the the {@link ObjectDetectionInputConverter}, {@link ObjectDetectionOutputConverter} and {@link TensorFlowService}
* in combination fromMemory the Tensorflow Object Detection API (https://github.com/tensorflow/models/tree/master/research/object_detection)
* models for detection objects in input images.
*
* All pre-trained models (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) and labels are supported.
*
* You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
* Just use the URI notation: (zoo model tar.gz url)#(name of the frozen model file name). To speedup the bootstrap
* performance you should consider downloading the models locally and use the file:/"path to my model" URI instead!
*
* The object category labels for the pre-trained models are available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
* Use the labels applicable for the model. Also, for performance reasons you may consider to download the labels
* and load them from file: instead.
*
* @author Christian Tzolov
*/
public class ObjectDetectionService {
/** Default list of fetch names for Box models. */
public static List<String> FETCH_NAMES = Arrays.asList(
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.NUM_DETECTIONS);
/** Default list of fetch names for mask supporting models. */
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.DETECTION_MASKS,
ObjectDetectionOutputConverter.NUM_DETECTIONS);
private final ObjectDetectionInputConverter inputConverter;
private final ObjectDetectionOutputConverter outputConverter;
private final TensorFlowService tensorFlowService;
public ObjectDetectionService() {
this("https://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz#frozen_inference_graph.pb",
"https://storage.googleapis.com/scdf-tensorflow-models/object-detection/mscoco_label_map.pbtxt",
0.4f, false, true);
}
/**
* Convenience constructor that would initialize all necessary internal components.
* @param modelUri URI of the pre-trained, frozen Tensorflow model.
* @param labelsUri URI of the pre-trained category labels.
* @param confidence Confidence threshold. Only objects detected wth confidence above this threshold will be returned.
* @param withMasks If a Mask model is selected then you can use this flag to extract the instance segmentation masks as well.
*/
public ObjectDetectionService(String modelUri, String labelsUri,
float confidence, boolean withMasks, boolean cacheModel) {
this.inputConverter = new ObjectDetectionInputConverter();
List<String> fetchNames = withMasks ? FETCH_NAMES_WITH_MASKS : FETCH_NAMES;
this.outputConverter = new ObjectDetectionOutputConverter(
new DefaultResourceLoader().getResource(labelsUri), confidence, fetchNames);
this.tensorFlowService = new TensorFlowService(
new DefaultResourceLoader().getResource(modelUri), fetchNames, cacheModel);
}
/**
* Generic constructor thea allow the converter to be pre-configured before passed to the service.
* @param inputConverter Converter from byte array to object detection input image tensor
* @param outputConverter Covets the object detection output tensors into {@link ObjectDetection } list
* @param tensorFlowService Java tensorflow runner instance
*/
public ObjectDetectionService(ObjectDetectionInputConverter inputConverter,
ObjectDetectionOutputConverter outputConverter, TensorFlowService tensorFlowService) {
this.inputConverter = inputConverter;
this.outputConverter = outputConverter;
this.tensorFlowService = tensorFlowService;
}
/**
* Detects objects in a single input image identified by its URI.
*
* @param imageUri input image's URI
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects
*/
public List<ObjectDetection> detect(String imageUri) {
try (InputStream is = new DefaultResourceLoader().getResource(imageUri).getInputStream()) {
return this.detect(StreamUtils.copyToByteArray(is));
}
catch (IOException e) {
e.printStackTrace();
throw new IllegalStateException("Failed to detect the image:" + imageUri, e);
}
}
/**
* Detects objects in a single {@link BufferedImage}.
*
* @param image Input image to detect objects from.
* @param format Image format (e.g. jpg, png ...) to use when converting the buffer into byte array.
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects in the input image
*/
public List<ObjectDetection> detect(BufferedImage image, String format) {
return this.detect(GraphicsUtils.toImageByteArray(image, format));
}
/**
* Detects objects from a single input image encoded as byte array.
*
* @param image Input image encoded as byte array
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects in the input image
*/
public List<ObjectDetection> detect(byte[] image) {
return this.inputConverter.andThen(this.tensorFlowService).andThen(this.outputConverter).apply(new byte[][] { image }).get(0);
}
/**
* Uses detects objects from a batch of input images encoded as byte array.
*
* @param batchedImages Batch of input images encoded as byte arrays. First dimension is the batch size and second the image bytes.
* @return Returns list of lists. For every input image in the batch a list of {@link ObjectDetection} domain objects representing detected objects in the input image.
*/
public List<List<ObjectDetection>> detect(byte[][] batchedImages) {
return this.inputConverter.andThen(this.tensorFlowService).andThen(this.outputConverter).apply(batchedImages);
}
}

View File

@@ -0,0 +1,113 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.types.UInt8;
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
/**
* @author Christian Tzolov
*/
public class ObjectDetectionService2 implements AutoCloseable {
/** Default Box models fetch names. */
public static List<String> FETCH_NAMES = Arrays.asList(
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.NUM_DETECTIONS);
/** Default Models models fetch names. */
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.DETECTION_MASKS,
ObjectDetectionOutputConverter.NUM_DETECTIONS);
private final GraphRunner imageNormalization;
private final GraphRunner objectDetection;
private final ObjectDetectionOutputConverter outputConverter;
public ObjectDetectionService2(String modelUri, ObjectDetectionOutputConverter outputConverter) {
this.imageNormalization = new GraphRunner("raw_image", "normalized_image")
.withGraphDefinition(tf -> {
Placeholder<String> rawImage = tf.withName("raw_image").placeholder(String.class);
Operand<UInt8> decodedImage = tf.dtypes.cast(
tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(3L)), UInt8.class);
// Expand dimensions since the model expects images to have shape: [1, H, W, 3]
tf.withName("normalized_image").expandDims(decodedImage, tf.constant(0));
});
this.objectDetection = new GraphRunner(Arrays.asList("image_tensor"), FETCH_NAMES)
.withGraphDefinition(new ProtoBufGraphDefinition(
new DefaultResourceLoader().getResource(modelUri), true));
this.outputConverter = outputConverter;
}
public List<ObjectDetection> detect(byte[] image) {
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memorize = new GraphRunnerMemory()) {
List<List<ObjectDetection>> out = this.imageNormalization.andThen(memorize)
.andThen(this.objectDetection).andThen(memorize)
.andThen(outputConverter)
.apply(Collections.singletonMap("raw_image", inputTensor));
return out.get(0);
}
}
@Override
public void close() {
this.imageNormalization.close();
this.objectDetection.close();
//this.outputConverter.close();
}
public static void main(String[] args) throws IOException {
String modelUri = "http://dl.bintray.com/big-data/generic/ssdlite_mobilenet_v2_coco_2018_05_09_frozen_inference_graph.pb";
String labelUri = "http://dl.bintray.com/big-data/generic/mscoco_label_map.pbtxt";
ObjectDetectionOutputConverter outputAdapter = new ObjectDetectionOutputConverter(
new DefaultResourceLoader().getResource(labelUri), 0.4f, FETCH_NAMES);
//byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/wild-animals-15.jpg");
try (ObjectDetectionService2 objectDetectionService2 = new ObjectDetectionService2(modelUri, outputAdapter)) {
List<ObjectDetection> boza = objectDetectionService2.detect(inputImage);
System.out.println(boza);
}
}
}

View File

@@ -0,0 +1,116 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.domain;
import java.util.Arrays;
import com.fasterxml.jackson.annotation.JsonInclude;
/**
* @author Christian Tzolov
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class ObjectDetection {
private String name;
private float confidence;
private float x1;
private float y1;
private float x2;
private float y2;
private float[][] mask;
private int cid;
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public float getConfidence() {
return confidence;
}
public void setConfidence(float confidence) {
this.confidence = confidence;
}
public float getX1() {
return x1;
}
public void setX1(float x1) {
this.x1 = x1;
}
public float getY1() {
return y1;
}
public void setY1(float y1) {
this.y1 = y1;
}
public float getX2() {
return x2;
}
public void setX2(float x2) {
this.x2 = x2;
}
public float getY2() {
return y2;
}
public void setY2(float y2) {
this.y2 = y2;
}
public int getCid() {
return cid;
}
public void setCid(int cid) {
this.cid = cid;
}
public float[][] getMask() {
return mask;
}
public void setMask(float[][] mask) {
this.mask = mask;
}
@Override
public String toString() {
return "ObjectDetection{" +
"name='" + name + '\'' +
", confidence=" + confidence +
", x1=" + x1 +
", y1=" + y1 +
", x2=" + x2 +
", y2=" + y2 +
", mask=" + Arrays.toString(mask) +
", cid=" + cid +
'}';
}
}

View File

@@ -0,0 +1,301 @@
// Protocol messages for describing input data Examples for machine learning
// model training or inference.
syntax = "proto3";
import "feature.proto";
option cc_enable_arenas = true;
option java_outer_classname = "ExampleProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.example";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
package tensorflow;
// An Example is a mostly-normalized data format for storing data for
// training and inference. It contains a key-value store (features); where
// each key (string) maps to a Feature message (which is oneof packed BytesList,
// FloatList, or Int64List). This flexible and compact format allows the
// storage of large amounts of typed data, but requires that the data shape
// and use be determined by the configuration files and parsers that are used to
// read and write this format. That is, the Example is mostly *not* a
// self-describing format. In TensorFlow, Examples are read in row-major
// format, so any configuration that describes data with rank-2 or above
// should keep this in mind. For example, to store an M x N matrix of Bytes,
// the BytesList must contain M*N bytes, with M rows of N contiguous values
// each. That is, the BytesList value must store the matrix as:
// .... row 0 .... .... row 1 .... // ........... // ... row M-1 ....
//
// An Example for a movie recommendation application:
// features {
// feature {
// key: "age"
// value { float_list {
// value: 29.0
// }}
// }
// feature {
// key: "movie"
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }}
// }
// feature {
// key: "movie_ratings"
// value { float_list {
// value: 9.0
// value: 9.7
// }}
// }
// feature {
// key: "suggestion"
// value { bytes_list {
// value: "Inception"
// }}
// }
// # Note that this feature exists to be used as a label in training.
// # E.g., if training a logistic regression model to predict purchase
// # probability in our learning tool we would set the label feature to
// # "suggestion_purchased".
// feature {
// key: "suggestion_purchased"
// value { float_list {
// value: 1.0
// }}
// }
// # Similar to "suggestion_purchased" above this feature exists to be used
// # as a label in training.
// # E.g., if training a linear regression model to predict purchase
// # price in our learning tool we would set the label feature to
// # "purchase_price".
// feature {
// key: "purchase_price"
// value { float_list {
// value: 9.99
// }}
// }
// }
//
// A conformant Example data set obeys the following conventions:
// - If a Feature K exists in one example with data type T, it must be of
// type T in all other examples when present. It may be omitted.
// - The number of instances of Feature K list data may vary across examples,
// depending on the requirements of the model.
// - If a Feature K doesn't exist in an example, a K-specific default will be
// used, if configured.
// - If a Feature K exists in an example but contains no items, the intent
// is considered to be an empty tensor and no default will be used.
message Example {
Features features = 1;
};
// A SequenceExample is an Example representing one or more sequences, and
// some context. The context contains features which apply to the entire
// example. The feature_lists contain a key, value map where each key is
// associated with a repeated set of Features (a FeatureList).
// A FeatureList thus represents the values of a feature identified by its key
// over time / frames.
//
// Below is a SequenceExample for a movie recommendation application recording a
// sequence of ratings by a user. The time-independent features ("locale",
// "age", "favorites") describing the user are part of the context. The sequence
// of movies the user rated are part of the feature_lists. For each movie in the
// sequence we have information on its name and actors and the user's rating.
// This information is recorded in three separate feature_list(s).
// In the example below there are only two movies. All three feature_list(s),
// namely "movie_ratings", "movie_names", and "actors" have a feature value for
// both movies. Note, that "actors" is itself a bytes_list with multiple
// strings per movie.
//
// context: {
// feature: {
// key : "locale"
// value: {
// bytes_list: {
// value: [ "pt_BR" ]
// }
// }
// }
// feature: {
// key : "age"
// value: {
// float_list: {
// value: [ 19.0 ]
// }
// }
// }
// feature: {
// key : "favorites"
// value: {
// bytes_list: {
// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ]
// }
// }
// }
// }
// feature_lists: {
// feature_list: {
// key : "movie_ratings"
// value: {
// feature: {
// float_list: {
// value: [ 4.5 ]
// }
// }
// feature: {
// float_list: {
// value: [ 5.0 ]
// }
// }
// }
// }
// feature_list: {
// key : "movie_names"
// value: {
// feature: {
// bytes_list: {
// value: [ "The Shawshank Redemption" ]
// }
// }
// feature: {
// bytes_list: {
// value: [ "Fight Club" ]
// }
// }
// }
// }
// feature_list: {
// key : "actors"
// value: {
// feature: {
// bytes_list: {
// value: [ "Tim Robbins", "Morgan Freeman" ]
// }
// }
// feature: {
// bytes_list: {
// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ]
// }
// }
// }
// }
// }
//
// A conformant SequenceExample data set obeys the following conventions:
//
// Context:
// - All conformant context features K must obey the same conventions as
// a conformant Example's features (see above).
// Feature lists:
// - A FeatureList L may be missing in an example; it is up to the
// parser configuration to determine if this is allowed or considered
// an empty list (zero length).
// - If a FeatureList L exists, it may be empty (zero length).
// - If a FeatureList L is non-empty, all features within the FeatureList
// must have the same data type T. Even across SequenceExamples, the type T
// of the FeatureList identified by the same key must be the same. An entry
// without any values may serve as an empty feature.
// - If a FeatureList L is non-empty, it is up to the parser configuration
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
// - For sequence modeling, e.g.:
// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
// https://github.com/tensorflow/nmt
// the feature lists represent a sequence of frames.
// In this scenario, all FeatureLists in a SequenceExample have the same
// number of Feature messages, so that the ith element in each FeatureList
// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
//
// Non-conformant FeatureLists (mismatched types):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { int64_list: { value: [ 5 ] } } }
// } }
//
// Conditionally conformant FeatureLists, the parser configuration determines
// if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0, 6.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } }
// feature: { float_list: { value: [ 2.0 ] } } }
// } }
//
// Conformant pair of SequenceExample
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { }
// } }
//
// Conditionally conformant pair of SequenceExample, the parser configuration
// determines if the second feature_lists is consistent (zero-length) or
// invalid (missing "movie_ratings"):
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { }
//
// Non-conformant pair of SequenceExample (mismatched types)
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { int64_list: { value: [ 4 ] } }
// feature: { int64_list: { value: [ 5 ] } }
// feature: { int64_list: { value: [ 2 ] } } }
// } }
//
// Conditionally conformant pair of SequenceExample; the parser configuration
// determines if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
// feature: { float_list: { value: [ 5.0 ] } } }
// } }
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.0 ] } }
// feature: { float_list: { value: [ 5.0, 3.0 ] } }
// } }
message SequenceExample {
Features context = 1;
FeatureLists feature_lists = 2;
};

View File

@@ -0,0 +1,105 @@
// Protocol messages for describing features for machine learning model
// training or inference.
//
// There are three base Feature types:
// - bytes
// - float
// - int64
//
// A Feature contains Lists which may hold zero or more values. These
// lists are the base values BytesList, FloatList, Int64List.
//
// Features are organized into categories by name. The Features message
// contains the mapping from name to Feature.
//
// Example Features for a movie recommendation application:
// feature {
// key: "age"
// value { float_list {
// value: 29.0
// }}
// }
// feature {
// key: "movie"
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }}
// }
// feature {
// key: "movie_ratings"
// value { float_list {
// value: 9.0
// value: 9.7
// }}
// }
// feature {
// key: "suggestion"
// value { bytes_list {
// value: "Inception"
// }}
// }
// feature {
// key: "suggestion_purchased"
// value { int64_list {
// value: 1
// }}
// }
// feature {
// key: "purchase_price"
// value { float_list {
// value: 9.99
// }}
// }
//
syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "FeatureProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.example";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
package tensorflow;
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
// Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
repeated Feature feature = 1;
};
message FeatureLists {
// Map from feature name to feature list.
map<string, FeatureList> feature_list = 1;
};

View File

@@ -0,0 +1,25 @@
// Message to store the mapping from class label strings to class id. Datasets
// use string labels to represent classes while the object detection framework
// works fromMemory class ids. This message maps them so they can be converted back
// and forth as needed.
syntax = "proto2";
package org.springframework.cloud.fn.object.detection.protos;
message StringIntLabelMapItem {
// String name. The most common practice is to set this to a MID or synsets
// id. Synset: a set of synonyms that share a common meaning.
// https://en.wikipedia.org/wiki/WordNet
optional string name = 1;
// Integer id that maps to the string name above. Label ids should start
// from 1.
optional int32 id = 2;
// Human readable string label.
optional string display_name = 3;
};
message StringIntLabelMap {
repeated StringIntLabelMapItem item = 1;
};

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

View File

@@ -0,0 +1,89 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.ResourceLoader;
/**
* 4 of the pre-trained model in the model zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)
* can also compute the masks of the detected objects, providing instance segmentation.
*
* Here are the models that can be used for instance segmentation.
*
* mask_rcnn_inception_resnet_v2_atrous_coco 771 36 Masks
* mask_rcnn_inception_v2_coco 79 25 Masks
* mask_rcnn_resnet101_atrous_coco 470 33 Masks
* mask_rcnn_resnet50_atrous_coco 343 29 Masks
*
* @author Christian Tzolov
*/
public class ExampleInstanceSegmentation {
public static void main(String[] args) throws IOException {
ResourceLoader resourceLoader = new DefaultResourceLoader();
// You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file name>
// For performance reasons you may consider downloading the model locally and use the file:/<path to my model> URI instead!
String model = "http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
// All labels for the pre-trained models are available at:
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
// Use the labels applicable for the model.
// Also, for performance reasons you may consider to download the labels and load them from file: instead.
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
// You can cache the TF model on the local file system to improve the bootstrap performance on consecutive runs!
boolean CACHE_TF_MODEL = true;
// For the pre-trained models fromMemory mask you can set the INSTANCE_SEGMENTATION to enable object instance segmentation as well
boolean INSTANCE_SEGMENTATION = true;
// Only object fromMemory confidence above the threshold are returned
float CONFIDENCE_THRESHOLD = 0.4f;
ObjectDetectionService detectionService =
new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD, INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
// You can use file:, http: or classpath: to provide the path to the input image.
byte[] image = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
// Returns a list ObjectDetection domain classes to allow programmatic accesses to the detected objects's metadata
List<ObjectDetection> detectedObjects = detectionService.detect(image);
// Get JSON representation of the detected objects
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
System.out.println(jsonObjectDetections);
// Draw the detected object metadata on top of the original image and store the result
byte[] annotatedImage = new ObjectDetectionImageAugmenter(INSTANCE_SEGMENTATION).apply(image, detectedObjects);
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-segmentation-augmented.jpg"));
}
}

View File

@@ -0,0 +1,84 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.util.StreamUtils;
/**
* @author Christian Tzolov
*/
public class ExampleObjectDetection {
public static void main(String[] args) throws IOException {
// You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file name>
// For performance reasons you may consider downloading the model locally and use the file:/<path to my model> URI instead!
String model = "http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
//Resource model = resourceLoader.getResource("http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
//Resource model = resourceLoader.getResource("http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
// All labels for the pre-trained models are available at:
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
// Use the labels applicable for the model.
// Also, for performance reasons you may consider to download the labels and load them from file: instead.
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
//Resource labels = resourceLoader.getResource("https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/fgvc_2854_classes_label_map.pbtxt");
// You can cache the TF model on the local file system to improve the bootstrap performance on consecutive runs!
boolean CACHE_TF_MODEL = true;
// For the pre-trained models fromMemory mask you can set the INSTANCE_SEGMENTATION to enable object instance segmentation as well
boolean NO_INSTANCE_SEGMENTATION = false;
// Only object fromMemory confidence above the threshold are returned
float CONFIDENCE_THRESHOLD = 0.4f;
ObjectDetectionService detectionService =
new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD, NO_INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
// You can use file:, http: or classpath: to provide the path to the input image.
String inputImageUri = "classpath:/images/object-detection.jpg";
try (InputStream is = new DefaultResourceLoader().getResource(inputImageUri).getInputStream()) {
byte[] image = StreamUtils.copyToByteArray(is);
// Returns a list ObjectDetection domain classes to allow programmatic accesses to the detected objects's metadata
List<ObjectDetection> detectedObjects = detectionService.detect(image);
// Get JSON representation of the detected objects
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
System.out.println(jsonObjectDetections);
// Draw the detected object metadata on top of the original image and store the result
byte[] annotatedImage = new ObjectDetectionImageAugmenter(NO_INSTANCE_SEGMENTATION).apply(image, detectedObjects);
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-augmented.jpg"));
}
}
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright 2020-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.fn.object.detection.examples;
import java.util.List;
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
/**
* @author Christian Tzolov
*/
public class SimpleExample {
public static void main(String[] args) {
// Select a pre-trained model from the model zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
// Just use the notation <model zoo url>#<name of the frozen model file in the zoo's tar.gz>
String model = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz#frozen_inference_graph.pb";
// All labels for the pre-trained models are available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
ObjectDetectionService detectionService = new ObjectDetectionService(model, labels,
0.4f, // Only object fromMemory confidence above the threshold are returned. Confidence range is [0, 1].
false, // No instance segmentation
true); // cache the TF model locally
// You can use file:, http: or classpath: to provide the path to the input image.
List<ObjectDetection> detectedObjects = detectionService.detect("classpath:/images/object-detection.jpg");
detectedObjects.stream().map(o -> o.toString()).forEach(System.out::println);
}
}