Tensorflow functions and applications
* initial step
* Tensorflow models functional model redesign
-- Based on https://tzolov.github.io/mind-model-services
-- Resolves #5
* Add object detection processor README
* Add image recognition processor README
* Initial Tensorflow commonn README
* Initial Tensorflow commonn README
* Tensorflow common diagram
* Tensorflow docs code
* Tensorflow docs code snippets improve
* Tensorflow docs code snippets improve
* Tensorflow docs code snippets improve
* Tensorflow docs code snippets improve
* Add semantic segmentation function. add object detecteion function readme
* oo images
* Furether oo readme improvments
* Final obj detection readme fixes
* Add image recognition readme
* Add image recognition readme 2
* Semantic segmentation readme
* Segmentation readme
* Semantic segmentation readme 3
* Fix image recognition and object detcion app starter dependecies
* Add metadata for Tensorflow apps
This commit is contained in:
committed by
Soby Chacko
parent
3bb9e066b9
commit
dffb467da4
@@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.awt.Color;
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Augment the input image fromMemory detected object bounding boxes and categories.
|
||||
* For mask models and withMask set to true it draws the instance segmentation image as well.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionImageAugmenter implements BiFunction<byte[], List<ObjectDetection>, byte[]> {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(ObjectDetectionImageAugmenter.class);
|
||||
|
||||
/** Make checkstyle happy. **/
|
||||
public static final String DEFAULT_IMAGE_FORMAT = "jpg";
|
||||
|
||||
private String imageFormat = DEFAULT_IMAGE_FORMAT;
|
||||
|
||||
private final boolean withMask;
|
||||
private boolean agnosticColors = false;
|
||||
|
||||
public ObjectDetectionImageAugmenter() {
|
||||
this(false);
|
||||
}
|
||||
|
||||
public ObjectDetectionImageAugmenter(boolean withMask) {
|
||||
this.withMask = withMask;
|
||||
}
|
||||
|
||||
public boolean isAgnosticColors() {
|
||||
return agnosticColors;
|
||||
}
|
||||
|
||||
public void setAgnosticColors(boolean agnosticColors) {
|
||||
this.agnosticColors = agnosticColors;
|
||||
}
|
||||
|
||||
public String getImageFormat() {
|
||||
return imageFormat;
|
||||
}
|
||||
|
||||
public void setImageFormat(String imageFormat) {
|
||||
this.imageFormat = imageFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] apply(byte[] imageBytes, List<ObjectDetection> objectDetections) {
|
||||
|
||||
if (!CollectionUtils.isEmpty(objectDetections)) {
|
||||
try {
|
||||
BufferedImage bufferedImage = ImageIO.read(new ByteArrayInputStream(imageBytes));
|
||||
|
||||
for (ObjectDetection od : objectDetections) {
|
||||
int y1 = (int) (od.getY1() * (float) bufferedImage.getHeight());
|
||||
int x1 = (int) (od.getX1() * (float) bufferedImage.getWidth());
|
||||
int y2 = (int) (od.getY2() * (float) bufferedImage.getHeight());
|
||||
int x2 = (int) (od.getX2() * (float) bufferedImage.getWidth());
|
||||
|
||||
int cid = od.getCid();
|
||||
|
||||
String labelName = od.getName();
|
||||
int probability = (int) (100 * od.getConfidence());
|
||||
String title = labelName + ": " + probability + "%";
|
||||
|
||||
GraphicsUtils.drawBoundingBox(bufferedImage, cid, title, x1, y1, x2, y2, this.agnosticColors);
|
||||
|
||||
if (this.withMask && od.getMask() != null) {
|
||||
float[][] mask = od.getMask();
|
||||
if (mask != null) {
|
||||
Color maskColor = this.agnosticColors ? null : GraphicsUtils.getClassColor(cid);
|
||||
BufferedImage maskImage = GraphicsUtils.createMaskImage(
|
||||
mask, x2 - x1, y2 - y1, maskColor);
|
||||
GraphicsUtils.overlayImages(bufferedImage, maskImage, x1, y1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imageBytes = GraphicsUtils.toImageByteArray(bufferedImage, this.getImageFormat());
|
||||
}
|
||||
catch (IOException e) {
|
||||
logger.error(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Null mend that QR image is found and not output message will be send.
|
||||
return imageBytes;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.op.core.Placeholder;
|
||||
import org.tensorflow.op.image.DecodeJpeg;
|
||||
import org.tensorflow.types.UInt8;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
|
||||
|
||||
/**
|
||||
* Converts byte array image into a input Tensor for the Object Detection API.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionInputAdapter implements Function<byte[], Map<String, Tensor<?>>>, AutoCloseable {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(ObjectDetectionInputAdapter.class);
|
||||
|
||||
/** Make checkstyle happy. **/
|
||||
public static final String RAW_IMAGE = "raw_image";
|
||||
/** Make checkstyle happy. **/
|
||||
public static final String NORMALIZED_IMAGE = "normalized_image";
|
||||
/** Make checkstyle happy. **/
|
||||
public static final long CHANNELS = 3;
|
||||
|
||||
private final GraphRunner imageLoaderGraph;
|
||||
|
||||
public ObjectDetectionInputAdapter() {
|
||||
|
||||
this.imageLoaderGraph = new GraphRunner(RAW_IMAGE, NORMALIZED_IMAGE)
|
||||
.withGraphDefinition(tf -> {
|
||||
Placeholder<String> rawImage = tf.withName(RAW_IMAGE).placeholder(String.class);
|
||||
Operand<UInt8> decodedImage = tf.dtypes.cast(
|
||||
tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(CHANNELS)), UInt8.class);
|
||||
// Expand dimensions since the model expects images to have shape: [1, H, W, 3]
|
||||
tf.withName(NORMALIZED_IMAGE).expandDims(decodedImage, tf.constant(0));
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Tensor<?>> apply(byte[] inputImage) {
|
||||
try (Tensor inputTensor = Tensor.create(inputImage)) {
|
||||
return this.imageLoaderGraph.apply(Collections.singletonMap(RAW_IMAGE, inputTensor));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (this.imageLoaderGraph != null) {
|
||||
this.imageLoaderGraph.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.awt.image.DataBufferByte;
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import javax.imageio.ImageIO;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.types.UInt8;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
|
||||
|
||||
/**
|
||||
* Converts byte array image into a input Tensor for the Object Detection API. The computed image tensors uses the
|
||||
* 'image_tensor' model placeholder.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionInputConverter implements Function<byte[][], Map<String, Tensor<?>>> {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(ObjectDetectionInputConverter.class);
|
||||
|
||||
private static final long CHANNELS = 3;
|
||||
|
||||
/** Make checkstyle happy. **/
|
||||
public static final String IMAGE_TENSOR_FEED_NAME = "image_tensor";
|
||||
|
||||
@Override
|
||||
public Map<String, Tensor<?>> apply(byte[][] input) {
|
||||
return Collections.singletonMap(IMAGE_TENSOR_FEED_NAME, makeImageTensor(input));
|
||||
}
|
||||
|
||||
private static Tensor<UInt8> makeImageTensor(byte[][] imageBytesArray) {
|
||||
try {
|
||||
int batchSize = imageBytesArray.length;
|
||||
ByteBuffer byteBuffer = null;
|
||||
long[] shape = null;
|
||||
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
|
||||
byte[] imageBytes = imageBytesArray[batchIndex];
|
||||
ByteArrayInputStream is = new ByteArrayInputStream(imageBytes);
|
||||
BufferedImage img = ImageIO.read(is);
|
||||
|
||||
if (img.getType() != BufferedImage.TYPE_3BYTE_BGR) {
|
||||
img = GraphicsUtils.toBufferedImageType(img, BufferedImage.TYPE_3BYTE_BGR);
|
||||
}
|
||||
|
||||
if (byteBuffer == null) {
|
||||
byteBuffer = ByteBuffer.allocate((int) (batchSize * img.getHeight() * img.getWidth() * CHANNELS));
|
||||
shape = new long[] { batchSize, img.getHeight(), img.getWidth(), CHANNELS };
|
||||
}
|
||||
|
||||
byte[] data = ((DataBufferByte) img.getData().getDataBuffer()).getData();
|
||||
// ImageIO.read produces BGR-encoded images, while the model expects RGB.
|
||||
bgrToRgb(data);
|
||||
byteBuffer.put(data);
|
||||
}
|
||||
byteBuffer.flip();
|
||||
|
||||
return Tensor.create(UInt8.class, shape, byteBuffer);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new IllegalArgumentException("Incorrect image format", e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static void bgrToRgb(byte[] data) {
|
||||
for (int i = 0; i < data.length; i += 3) {
|
||||
byte tmp = data[i];
|
||||
data[i] = data[i + 2];
|
||||
data[i + 2] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import com.google.protobuf.TextFormat;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.tensorflow.Tensor;
|
||||
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.cloud.fn.object.detection.protos.StringIntLabelMapOuterClass;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Converts the Tensorflow Object Detection result into {@link ObjectDetection} list.
|
||||
* The pre-trained Object Detection models (http://bit.ly/2osxMAY) produce 3 tensor outputs:
|
||||
* (1) detection_classes - containing the ids of detected objects, (2) detection_scores - confidence probabilities of the
|
||||
* detected object and (3) detection_boxes - the object bounding boxes withing the images.
|
||||
*
|
||||
* The MASK based models provide to 2 additional tensors: (4) num_detections and (5) detection_masks.
|
||||
*
|
||||
* All outputs tensors are float arrays, having:
|
||||
* - 1 as the first dimension
|
||||
* - maxObjects as the second dimension
|
||||
* While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
|
||||
* This can be verified by looking at scoresT.shape() etc.
|
||||
*
|
||||
* The format detected classes (e.g. labels) names is defined by the 'string_int_labels_map.proto'. The input list
|
||||
* is available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionOutputConverter implements Function<Map<String, Tensor<?>>, List<List<ObjectDetection>>> {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(ObjectDetectionOutputConverter.class);
|
||||
|
||||
/** DETECTION_CLASSES. */
|
||||
public static final String DETECTION_CLASSES = "detection_classes";
|
||||
/** DETECTION_SCORES. */
|
||||
public static final String DETECTION_SCORES = "detection_scores";
|
||||
/** DETECTION_BOXES. */
|
||||
public static final String DETECTION_BOXES = "detection_boxes";
|
||||
/** DETECTION_MASKS. */
|
||||
public static final String DETECTION_MASKS = "detection_masks";
|
||||
/** NUM_DETECTIONS. */
|
||||
public static final String NUM_DETECTIONS = "num_detections";
|
||||
|
||||
private final String[] labels;
|
||||
private float confidence;
|
||||
private List<String> modelFetch;
|
||||
|
||||
public ObjectDetectionOutputConverter(Resource labelsResource, float confidence, List<String> modelFetch) {
|
||||
this.confidence = confidence;
|
||||
this.modelFetch = modelFetch;
|
||||
try {
|
||||
this.labels = loadLabels(labelsResource);
|
||||
Assert.notNull(this.labels, String.format("Failed to initialize object labels [%s].", labelsResource));
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new RuntimeException(String.format("Failed to initialize object labels [%s].", labelsResource), e);
|
||||
}
|
||||
|
||||
logger.info(String.format("Object labels [%s] loaded.", labelsResource));
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads object labels in the string_int_label_map.proto.
|
||||
* @param labelsResource location of the labels as a resource
|
||||
* @return
|
||||
*/
|
||||
private static String[] loadLabels(Resource labelsResource) throws Exception {
|
||||
try (InputStream is = labelsResource.getInputStream()) {
|
||||
String text = StreamUtils.copyToString(is, Charset.forName("UTF-8"));
|
||||
StringIntLabelMapOuterClass.StringIntLabelMap.Builder builder =
|
||||
StringIntLabelMapOuterClass.StringIntLabelMap.newBuilder();
|
||||
TextFormat.merge(text, builder);
|
||||
StringIntLabelMapOuterClass.StringIntLabelMap proto = builder.build();
|
||||
|
||||
int maxLabelId = proto.getItemList().stream()
|
||||
.map(StringIntLabelMapOuterClass.StringIntLabelMapItem::getId)
|
||||
.max(Comparator.comparing(i -> i))
|
||||
.orElse(-1);
|
||||
|
||||
String[] labelIdToNameMap = new String[maxLabelId + 1];
|
||||
for (StringIntLabelMapOuterClass.StringIntLabelMapItem item : proto.getItemList()) {
|
||||
if (!StringUtils.isEmpty(item.getDisplayName())) {
|
||||
labelIdToNameMap[item.getId()] = item.getDisplayName();
|
||||
}
|
||||
else {
|
||||
// Common practice is to set the name to a MID or Synsets Id. Synset is a set of synonyms that
|
||||
// share a common meaning: https://en.wikipedia.org/wiki/WordNet
|
||||
labelIdToNameMap[item.getId()] = item.getName();
|
||||
}
|
||||
}
|
||||
return labelIdToNameMap;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<ObjectDetection>> apply(Map<String, Tensor<?>> tensorMap) {
|
||||
|
||||
try (Tensor<Float> scoresTensor = tensorMap.get(DETECTION_SCORES).expect(Float.class);
|
||||
Tensor<Float> classesTensor = tensorMap.get(DETECTION_CLASSES).expect(Float.class);
|
||||
Tensor<Float> boxesTensor = tensorMap.get(DETECTION_BOXES).expect(Float.class)
|
||||
) {
|
||||
// All these tensors have:
|
||||
// - 1 as the first dimension
|
||||
// - maxObjects as the second dimension
|
||||
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
|
||||
// This can be verified by looking at scoresT.shape() etc.
|
||||
int batchSize = (int) scoresTensor.shape()[0];
|
||||
int maxObjects = (int) scoresTensor.shape()[1];
|
||||
float[][] scores = scoresTensor.copyTo(new float[batchSize][maxObjects]);
|
||||
float[][] classes = classesTensor.copyTo(new float[batchSize][maxObjects]);
|
||||
float[][][] boxes = boxesTensor.copyTo(new float[batchSize][maxObjects][4]);
|
||||
|
||||
List<List<ObjectDetection>> batchObjectDetections = new ArrayList<>();
|
||||
|
||||
for (int batchIndex = 0; batchIndex < batchSize; batchIndex++) {
|
||||
|
||||
|
||||
List<ObjectDetection> objectDetections = new ArrayList<>();
|
||||
|
||||
// Collect only the objects whose scores are at above the configured confidence threshold.
|
||||
for (int i = 0; i < scores[batchIndex].length; ++i) {
|
||||
if (scores[batchIndex][i] >= confidence) {
|
||||
String category = labels[(int) classes[batchIndex][i]];
|
||||
float score = scores[batchIndex][i];
|
||||
|
||||
ObjectDetection od = new ObjectDetection();
|
||||
od.setName(category);
|
||||
od.setConfidence(score);
|
||||
od.setX1(boxes[batchIndex][i][1]);
|
||||
od.setY1(boxes[batchIndex][i][0]);
|
||||
od.setX2(boxes[batchIndex][i][3]);
|
||||
od.setY2(boxes[batchIndex][i][2]);
|
||||
od.setCid((int) classes[batchIndex][i]);
|
||||
|
||||
// Mask allows image-segmentation
|
||||
if (modelFetch.contains(DETECTION_MASKS) && modelFetch.contains(NUM_DETECTIONS)) {
|
||||
Tensor<Float> masksTensor = tensorMap.get(DETECTION_MASKS).expect(Float.class);
|
||||
Tensor<Float> numDetections = tensorMap.get(NUM_DETECTIONS).expect(Float.class);
|
||||
float nd = numDetections.copyTo(new float[batchSize])[0];
|
||||
|
||||
if (masksTensor != null) {
|
||||
long[] shape = masksTensor.shape();
|
||||
float[][][][] masks = masksTensor.copyTo(new float[(int) shape[0]][(int) shape[1]][(int) shape[2]][(int) shape[3]]);
|
||||
od.setMask(masks[batchIndex][i]);
|
||||
logger.debug(String.format("Num detections: %s, Masks: %s", nd, masks));
|
||||
}
|
||||
}
|
||||
|
||||
objectDetections.add(od);
|
||||
}
|
||||
}
|
||||
|
||||
batchObjectDetections.add(objectDetections);
|
||||
}
|
||||
return batchObjectDetections;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.TensorFlowService;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
/**
|
||||
* Convenience class that leverages the the {@link ObjectDetectionInputConverter}, {@link ObjectDetectionOutputConverter} and {@link TensorFlowService}
|
||||
* in combination fromMemory the Tensorflow Object Detection API (https://github.com/tensorflow/models/tree/master/research/object_detection)
|
||||
* models for detection objects in input images.
|
||||
*
|
||||
* All pre-trained models (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) and labels are supported.
|
||||
*
|
||||
* You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||||
* Just use the URI notation: (zoo model tar.gz url)#(name of the frozen model file name). To speedup the bootstrap
|
||||
* performance you should consider downloading the models locally and use the file:/"path to my model" URI instead!
|
||||
*
|
||||
* The object category labels for the pre-trained models are available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
|
||||
* Use the labels applicable for the model. Also, for performance reasons you may consider to download the labels
|
||||
* and load them from file: instead.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionService {
|
||||
|
||||
/** Default list of fetch names for Box models. */
|
||||
public static List<String> FETCH_NAMES = Arrays.asList(
|
||||
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
|
||||
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.NUM_DETECTIONS);
|
||||
|
||||
/** Default list of fetch names for mask supporting models. */
|
||||
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(
|
||||
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
|
||||
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.DETECTION_MASKS,
|
||||
ObjectDetectionOutputConverter.NUM_DETECTIONS);
|
||||
|
||||
private final ObjectDetectionInputConverter inputConverter;
|
||||
private final ObjectDetectionOutputConverter outputConverter;
|
||||
private final TensorFlowService tensorFlowService;
|
||||
|
||||
public ObjectDetectionService() {
|
||||
this("https://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz#frozen_inference_graph.pb",
|
||||
"https://storage.googleapis.com/scdf-tensorflow-models/object-detection/mscoco_label_map.pbtxt",
|
||||
0.4f, false, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convenience constructor that would initialize all necessary internal components.
|
||||
* @param modelUri URI of the pre-trained, frozen Tensorflow model.
|
||||
* @param labelsUri URI of the pre-trained category labels.
|
||||
* @param confidence Confidence threshold. Only objects detected wth confidence above this threshold will be returned.
|
||||
* @param withMasks If a Mask model is selected then you can use this flag to extract the instance segmentation masks as well.
|
||||
*/
|
||||
public ObjectDetectionService(String modelUri, String labelsUri,
|
||||
float confidence, boolean withMasks, boolean cacheModel) {
|
||||
this.inputConverter = new ObjectDetectionInputConverter();
|
||||
List<String> fetchNames = withMasks ? FETCH_NAMES_WITH_MASKS : FETCH_NAMES;
|
||||
this.outputConverter = new ObjectDetectionOutputConverter(
|
||||
new DefaultResourceLoader().getResource(labelsUri), confidence, fetchNames);
|
||||
this.tensorFlowService = new TensorFlowService(
|
||||
new DefaultResourceLoader().getResource(modelUri), fetchNames, cacheModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic constructor thea allow the converter to be pre-configured before passed to the service.
|
||||
* @param inputConverter Converter from byte array to object detection input image tensor
|
||||
* @param outputConverter Covets the object detection output tensors into {@link ObjectDetection } list
|
||||
* @param tensorFlowService Java tensorflow runner instance
|
||||
*/
|
||||
public ObjectDetectionService(ObjectDetectionInputConverter inputConverter,
|
||||
ObjectDetectionOutputConverter outputConverter, TensorFlowService tensorFlowService) {
|
||||
this.inputConverter = inputConverter;
|
||||
this.outputConverter = outputConverter;
|
||||
this.tensorFlowService = tensorFlowService;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects objects in a single input image identified by its URI.
|
||||
*
|
||||
* @param imageUri input image's URI
|
||||
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects
|
||||
*/
|
||||
public List<ObjectDetection> detect(String imageUri) {
|
||||
try (InputStream is = new DefaultResourceLoader().getResource(imageUri).getInputStream()) {
|
||||
return this.detect(StreamUtils.copyToByteArray(is));
|
||||
}
|
||||
catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
throw new IllegalStateException("Failed to detect the image:" + imageUri, e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects objects in a single {@link BufferedImage}.
|
||||
*
|
||||
* @param image Input image to detect objects from.
|
||||
* @param format Image format (e.g. jpg, png ...) to use when converting the buffer into byte array.
|
||||
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects in the input image
|
||||
*/
|
||||
public List<ObjectDetection> detect(BufferedImage image, String format) {
|
||||
return this.detect(GraphicsUtils.toImageByteArray(image, format));
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects objects from a single input image encoded as byte array.
|
||||
*
|
||||
* @param image Input image encoded as byte array
|
||||
* @return Returns a list of {@link ObjectDetection} domain objects representing detected objects in the input image
|
||||
*/
|
||||
public List<ObjectDetection> detect(byte[] image) {
|
||||
return this.inputConverter.andThen(this.tensorFlowService).andThen(this.outputConverter).apply(new byte[][] { image }).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Uses detects objects from a batch of input images encoded as byte array.
|
||||
*
|
||||
* @param batchedImages Batch of input images encoded as byte arrays. First dimension is the batch size and second the image bytes.
|
||||
* @return Returns list of lists. For every input image in the batch a list of {@link ObjectDetection} domain objects representing detected objects in the input image.
|
||||
*/
|
||||
public List<List<ObjectDetection>> detect(byte[][] batchedImages) {
|
||||
return this.inputConverter.andThen(this.tensorFlowService).andThen(this.outputConverter).apply(batchedImages);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import org.tensorflow.Operand;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.op.core.Placeholder;
|
||||
import org.tensorflow.op.image.DecodeJpeg;
|
||||
import org.tensorflow.types.UInt8;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
|
||||
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
|
||||
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ObjectDetectionService2 implements AutoCloseable {
|
||||
|
||||
/** Default Box models fetch names. */
|
||||
public static List<String> FETCH_NAMES = Arrays.asList(
|
||||
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
|
||||
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.NUM_DETECTIONS);
|
||||
|
||||
/** Default Models models fetch names. */
|
||||
public static List<String> FETCH_NAMES_WITH_MASKS = Arrays.asList(
|
||||
ObjectDetectionOutputConverter.DETECTION_SCORES, ObjectDetectionOutputConverter.DETECTION_CLASSES,
|
||||
ObjectDetectionOutputConverter.DETECTION_BOXES, ObjectDetectionOutputConverter.DETECTION_MASKS,
|
||||
ObjectDetectionOutputConverter.NUM_DETECTIONS);
|
||||
|
||||
private final GraphRunner imageNormalization;
|
||||
private final GraphRunner objectDetection;
|
||||
private final ObjectDetectionOutputConverter outputConverter;
|
||||
|
||||
|
||||
public ObjectDetectionService2(String modelUri, ObjectDetectionOutputConverter outputConverter) {
|
||||
|
||||
this.imageNormalization = new GraphRunner("raw_image", "normalized_image")
|
||||
.withGraphDefinition(tf -> {
|
||||
Placeholder<String> rawImage = tf.withName("raw_image").placeholder(String.class);
|
||||
Operand<UInt8> decodedImage = tf.dtypes.cast(
|
||||
tf.image.decodeJpeg(rawImage, DecodeJpeg.channels(3L)), UInt8.class);
|
||||
// Expand dimensions since the model expects images to have shape: [1, H, W, 3]
|
||||
tf.withName("normalized_image").expandDims(decodedImage, tf.constant(0));
|
||||
});
|
||||
|
||||
this.objectDetection = new GraphRunner(Arrays.asList("image_tensor"), FETCH_NAMES)
|
||||
.withGraphDefinition(new ProtoBufGraphDefinition(
|
||||
new DefaultResourceLoader().getResource(modelUri), true));
|
||||
|
||||
this.outputConverter = outputConverter;
|
||||
}
|
||||
|
||||
public List<ObjectDetection> detect(byte[] image) {
|
||||
try (Tensor inputTensor = Tensor.create(image); GraphRunnerMemory memorize = new GraphRunnerMemory()) {
|
||||
|
||||
List<List<ObjectDetection>> out = this.imageNormalization.andThen(memorize)
|
||||
.andThen(this.objectDetection).andThen(memorize)
|
||||
.andThen(outputConverter)
|
||||
.apply(Collections.singletonMap("raw_image", inputTensor));
|
||||
|
||||
return out.get(0);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
this.imageNormalization.close();
|
||||
this.objectDetection.close();
|
||||
//this.outputConverter.close();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
String modelUri = "http://dl.bintray.com/big-data/generic/ssdlite_mobilenet_v2_coco_2018_05_09_frozen_inference_graph.pb";
|
||||
String labelUri = "http://dl.bintray.com/big-data/generic/mscoco_label_map.pbtxt";
|
||||
|
||||
ObjectDetectionOutputConverter outputAdapter = new ObjectDetectionOutputConverter(
|
||||
new DefaultResourceLoader().getResource(labelUri), 0.4f, FETCH_NAMES);
|
||||
|
||||
//byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
|
||||
byte[] inputImage = GraphicsUtils.loadAsByteArray("classpath:/images/wild-animals-15.jpg");
|
||||
|
||||
try (ObjectDetectionService2 objectDetectionService2 = new ObjectDetectionService2(modelUri, outputAdapter)) {
|
||||
|
||||
List<ObjectDetection> boza = objectDetectionService2.detect(inputImage);
|
||||
|
||||
System.out.println(boza);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection.domain;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public class ObjectDetection {
|
||||
|
||||
private String name;
|
||||
private float confidence;
|
||||
private float x1;
|
||||
private float y1;
|
||||
private float x2;
|
||||
private float y2;
|
||||
private float[][] mask;
|
||||
private int cid;
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public float getConfidence() {
|
||||
return confidence;
|
||||
}
|
||||
|
||||
public void setConfidence(float confidence) {
|
||||
this.confidence = confidence;
|
||||
}
|
||||
|
||||
public float getX1() {
|
||||
return x1;
|
||||
}
|
||||
|
||||
public void setX1(float x1) {
|
||||
this.x1 = x1;
|
||||
}
|
||||
|
||||
public float getY1() {
|
||||
return y1;
|
||||
}
|
||||
|
||||
public void setY1(float y1) {
|
||||
this.y1 = y1;
|
||||
}
|
||||
|
||||
public float getX2() {
|
||||
return x2;
|
||||
}
|
||||
|
||||
public void setX2(float x2) {
|
||||
this.x2 = x2;
|
||||
}
|
||||
|
||||
public float getY2() {
|
||||
return y2;
|
||||
}
|
||||
|
||||
public void setY2(float y2) {
|
||||
this.y2 = y2;
|
||||
}
|
||||
|
||||
public int getCid() {
|
||||
return cid;
|
||||
}
|
||||
|
||||
public void setCid(int cid) {
|
||||
this.cid = cid;
|
||||
}
|
||||
|
||||
public float[][] getMask() {
|
||||
return mask;
|
||||
}
|
||||
|
||||
public void setMask(float[][] mask) {
|
||||
this.mask = mask;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ObjectDetection{" +
|
||||
"name='" + name + '\'' +
|
||||
", confidence=" + confidence +
|
||||
", x1=" + x1 +
|
||||
", y1=" + y1 +
|
||||
", x2=" + x2 +
|
||||
", y2=" + y2 +
|
||||
", mask=" + Arrays.toString(mask) +
|
||||
", cid=" + cid +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,301 @@
|
||||
// Protocol messages for describing input data Examples for machine learning
|
||||
// model training or inference.
|
||||
syntax = "proto3";
|
||||
|
||||
import "feature.proto";
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "ExampleProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.example";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
|
||||
package tensorflow;
|
||||
|
||||
// An Example is a mostly-normalized data format for storing data for
|
||||
// training and inference. It contains a key-value store (features); where
|
||||
// each key (string) maps to a Feature message (which is oneof packed BytesList,
|
||||
// FloatList, or Int64List). This flexible and compact format allows the
|
||||
// storage of large amounts of typed data, but requires that the data shape
|
||||
// and use be determined by the configuration files and parsers that are used to
|
||||
// read and write this format. That is, the Example is mostly *not* a
|
||||
// self-describing format. In TensorFlow, Examples are read in row-major
|
||||
// format, so any configuration that describes data with rank-2 or above
|
||||
// should keep this in mind. For example, to store an M x N matrix of Bytes,
|
||||
// the BytesList must contain M*N bytes, with M rows of N contiguous values
|
||||
// each. That is, the BytesList value must store the matrix as:
|
||||
// .... row 0 .... .... row 1 .... // ........... // ... row M-1 ....
|
||||
//
|
||||
// An Example for a movie recommendation application:
|
||||
// features {
|
||||
// feature {
|
||||
// key: "age"
|
||||
// value { float_list {
|
||||
// value: 29.0
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "movie"
|
||||
// value { bytes_list {
|
||||
// value: "The Shawshank Redemption"
|
||||
// value: "Fight Club"
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "movie_ratings"
|
||||
// value { float_list {
|
||||
// value: 9.0
|
||||
// value: 9.7
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "suggestion"
|
||||
// value { bytes_list {
|
||||
// value: "Inception"
|
||||
// }}
|
||||
// }
|
||||
// # Note that this feature exists to be used as a label in training.
|
||||
// # E.g., if training a logistic regression model to predict purchase
|
||||
// # probability in our learning tool we would set the label feature to
|
||||
// # "suggestion_purchased".
|
||||
// feature {
|
||||
// key: "suggestion_purchased"
|
||||
// value { float_list {
|
||||
// value: 1.0
|
||||
// }}
|
||||
// }
|
||||
// # Similar to "suggestion_purchased" above this feature exists to be used
|
||||
// # as a label in training.
|
||||
// # E.g., if training a linear regression model to predict purchase
|
||||
// # price in our learning tool we would set the label feature to
|
||||
// # "purchase_price".
|
||||
// feature {
|
||||
// key: "purchase_price"
|
||||
// value { float_list {
|
||||
// value: 9.99
|
||||
// }}
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// A conformant Example data set obeys the following conventions:
|
||||
// - If a Feature K exists in one example with data type T, it must be of
|
||||
// type T in all other examples when present. It may be omitted.
|
||||
// - The number of instances of Feature K list data may vary across examples,
|
||||
// depending on the requirements of the model.
|
||||
// - If a Feature K doesn't exist in an example, a K-specific default will be
|
||||
// used, if configured.
|
||||
// - If a Feature K exists in an example but contains no items, the intent
|
||||
// is considered to be an empty tensor and no default will be used.
|
||||
|
||||
message Example {
|
||||
Features features = 1;
|
||||
};
|
||||
|
||||
// A SequenceExample is an Example representing one or more sequences, and
|
||||
// some context. The context contains features which apply to the entire
|
||||
// example. The feature_lists contain a key, value map where each key is
|
||||
// associated with a repeated set of Features (a FeatureList).
|
||||
// A FeatureList thus represents the values of a feature identified by its key
|
||||
// over time / frames.
|
||||
//
|
||||
// Below is a SequenceExample for a movie recommendation application recording a
|
||||
// sequence of ratings by a user. The time-independent features ("locale",
|
||||
// "age", "favorites") describing the user are part of the context. The sequence
|
||||
// of movies the user rated are part of the feature_lists. For each movie in the
|
||||
// sequence we have information on its name and actors and the user's rating.
|
||||
// This information is recorded in three separate feature_list(s).
|
||||
// In the example below there are only two movies. All three feature_list(s),
|
||||
// namely "movie_ratings", "movie_names", and "actors" have a feature value for
|
||||
// both movies. Note, that "actors" is itself a bytes_list with multiple
|
||||
// strings per movie.
|
||||
//
|
||||
// context: {
|
||||
// feature: {
|
||||
// key : "locale"
|
||||
// value: {
|
||||
// bytes_list: {
|
||||
// value: [ "pt_BR" ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// feature: {
|
||||
// key : "age"
|
||||
// value: {
|
||||
// float_list: {
|
||||
// value: [ 19.0 ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// feature: {
|
||||
// key : "favorites"
|
||||
// value: {
|
||||
// bytes_list: {
|
||||
// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// feature_lists: {
|
||||
// feature_list: {
|
||||
// key : "movie_ratings"
|
||||
// value: {
|
||||
// feature: {
|
||||
// float_list: {
|
||||
// value: [ 4.5 ]
|
||||
// }
|
||||
// }
|
||||
// feature: {
|
||||
// float_list: {
|
||||
// value: [ 5.0 ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// feature_list: {
|
||||
// key : "movie_names"
|
||||
// value: {
|
||||
// feature: {
|
||||
// bytes_list: {
|
||||
// value: [ "The Shawshank Redemption" ]
|
||||
// }
|
||||
// }
|
||||
// feature: {
|
||||
// bytes_list: {
|
||||
// value: [ "Fight Club" ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// feature_list: {
|
||||
// key : "actors"
|
||||
// value: {
|
||||
// feature: {
|
||||
// bytes_list: {
|
||||
// value: [ "Tim Robbins", "Morgan Freeman" ]
|
||||
// }
|
||||
// }
|
||||
// feature: {
|
||||
// bytes_list: {
|
||||
// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ]
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// A conformant SequenceExample data set obeys the following conventions:
|
||||
//
|
||||
// Context:
|
||||
// - All conformant context features K must obey the same conventions as
|
||||
// a conformant Example's features (see above).
|
||||
// Feature lists:
|
||||
// - A FeatureList L may be missing in an example; it is up to the
|
||||
// parser configuration to determine if this is allowed or considered
|
||||
// an empty list (zero length).
|
||||
// - If a FeatureList L exists, it may be empty (zero length).
|
||||
// - If a FeatureList L is non-empty, all features within the FeatureList
|
||||
// must have the same data type T. Even across SequenceExamples, the type T
|
||||
// of the FeatureList identified by the same key must be the same. An entry
|
||||
// without any values may serve as an empty feature.
|
||||
// - If a FeatureList L is non-empty, it is up to the parser configuration
|
||||
// to determine if all features within the FeatureList must
|
||||
// have the same size. The same holds for this FeatureList across multiple
|
||||
// examples.
|
||||
// - For sequence modeling, e.g.:
|
||||
// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
|
||||
// https://github.com/tensorflow/nmt
|
||||
// the feature lists represent a sequence of frames.
|
||||
// In this scenario, all FeatureLists in a SequenceExample have the same
|
||||
// number of Feature messages, so that the ith element in each FeatureList
|
||||
// is part of the ith frame (or time step).
|
||||
// Examples of conformant and non-conformant examples' FeatureLists:
|
||||
//
|
||||
// Conformant FeatureLists:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
//
|
||||
// Non-conformant FeatureLists (mismatched types):
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { int64_list: { value: [ 5 ] } } }
|
||||
// } }
|
||||
//
|
||||
// Conditionally conformant FeatureLists, the parser configuration determines
|
||||
// if the feature sizes must match:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0, 6.0 ] } } }
|
||||
// } }
|
||||
//
|
||||
// Conformant pair of SequenceExample
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
// and:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } }
|
||||
// feature: { float_list: { value: [ 2.0 ] } } }
|
||||
// } }
|
||||
//
|
||||
// Conformant pair of SequenceExample
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
// and:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { }
|
||||
// } }
|
||||
//
|
||||
// Conditionally conformant pair of SequenceExample, the parser configuration
|
||||
// determines if the second feature_lists is consistent (zero-length) or
|
||||
// invalid (missing "movie_ratings"):
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
// and:
|
||||
// feature_lists: { }
|
||||
//
|
||||
// Non-conformant pair of SequenceExample (mismatched types)
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
// and:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { int64_list: { value: [ 4 ] } }
|
||||
// feature: { int64_list: { value: [ 5 ] } }
|
||||
// feature: { int64_list: { value: [ 2 ] } } }
|
||||
// } }
|
||||
//
|
||||
// Conditionally conformant pair of SequenceExample; the parser configuration
|
||||
// determines if the feature sizes must match:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.5 ] } }
|
||||
// feature: { float_list: { value: [ 5.0 ] } } }
|
||||
// } }
|
||||
// and:
|
||||
// feature_lists: { feature_list: {
|
||||
// key: "movie_ratings"
|
||||
// value: { feature: { float_list: { value: [ 4.0 ] } }
|
||||
// feature: { float_list: { value: [ 5.0, 3.0 ] } }
|
||||
// } }
|
||||
|
||||
message SequenceExample {
|
||||
Features context = 1;
|
||||
FeatureLists feature_lists = 2;
|
||||
};
|
||||
@@ -0,0 +1,105 @@
|
||||
// Protocol messages for describing features for machine learning model
|
||||
// training or inference.
|
||||
//
|
||||
// There are three base Feature types:
|
||||
// - bytes
|
||||
// - float
|
||||
// - int64
|
||||
//
|
||||
// A Feature contains Lists which may hold zero or more values. These
|
||||
// lists are the base values BytesList, FloatList, Int64List.
|
||||
//
|
||||
// Features are organized into categories by name. The Features message
|
||||
// contains the mapping from name to Feature.
|
||||
//
|
||||
// Example Features for a movie recommendation application:
|
||||
// feature {
|
||||
// key: "age"
|
||||
// value { float_list {
|
||||
// value: 29.0
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "movie"
|
||||
// value { bytes_list {
|
||||
// value: "The Shawshank Redemption"
|
||||
// value: "Fight Club"
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "movie_ratings"
|
||||
// value { float_list {
|
||||
// value: 9.0
|
||||
// value: 9.7
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "suggestion"
|
||||
// value { bytes_list {
|
||||
// value: "Inception"
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "suggestion_purchased"
|
||||
// value { int64_list {
|
||||
// value: 1
|
||||
// }}
|
||||
// }
|
||||
// feature {
|
||||
// key: "purchase_price"
|
||||
// value { float_list {
|
||||
// value: 9.99
|
||||
// }}
|
||||
// }
|
||||
//
|
||||
|
||||
syntax = "proto3";
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "FeatureProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.example";
|
||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example";
|
||||
package tensorflow;
|
||||
|
||||
// Containers to hold repeated fundamental values.
|
||||
message BytesList {
|
||||
repeated bytes value = 1;
|
||||
}
|
||||
message FloatList {
|
||||
repeated float value = 1 [packed = true];
|
||||
}
|
||||
message Int64List {
|
||||
repeated int64 value = 1 [packed = true];
|
||||
}
|
||||
|
||||
// Containers for non-sequential data.
|
||||
message Feature {
|
||||
// Each feature can be exactly one kind.
|
||||
oneof kind {
|
||||
BytesList bytes_list = 1;
|
||||
FloatList float_list = 2;
|
||||
Int64List int64_list = 3;
|
||||
}
|
||||
};
|
||||
|
||||
message Features {
|
||||
// Map from feature name to feature.
|
||||
map<string, Feature> feature = 1;
|
||||
};
|
||||
|
||||
// Containers for sequential data.
|
||||
//
|
||||
// A FeatureList contains lists of Features. These may hold zero or more
|
||||
// Feature values.
|
||||
//
|
||||
// FeatureLists are organized into categories by name. The FeatureLists message
|
||||
// contains the mapping from name to FeatureList.
|
||||
//
|
||||
message FeatureList {
|
||||
repeated Feature feature = 1;
|
||||
};
|
||||
|
||||
message FeatureLists {
|
||||
// Map from feature name to feature list.
|
||||
map<string, FeatureList> feature_list = 1;
|
||||
};
|
||||
@@ -0,0 +1,25 @@
|
||||
// Message to store the mapping from class label strings to class id. Datasets
|
||||
// use string labels to represent classes while the object detection framework
|
||||
// works fromMemory class ids. This message maps them so they can be converted back
|
||||
// and forth as needed.
|
||||
syntax = "proto2";
|
||||
|
||||
package org.springframework.cloud.fn.object.detection.protos;
|
||||
|
||||
message StringIntLabelMapItem {
|
||||
// String name. The most common practice is to set this to a MID or synsets
|
||||
// id. Synset: a set of synonyms that share a common meaning.
|
||||
// https://en.wikipedia.org/wiki/WordNet
|
||||
optional string name = 1;
|
||||
|
||||
// Integer id that maps to the string name above. Label ids should start
|
||||
// from 1.
|
||||
optional int32 id = 2;
|
||||
|
||||
// Human readable string label.
|
||||
optional string display_name = 3;
|
||||
};
|
||||
|
||||
message StringIntLabelMap {
|
||||
repeated StringIntLabelMapItem item = 1;
|
||||
};
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 144 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 140 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
@@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection.examples;
|
||||
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
|
||||
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
|
||||
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.springframework.core.io.ResourceLoader;
|
||||
|
||||
/**
|
||||
* 4 of the pre-trained model in the model zoo (https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)
|
||||
* can also compute the masks of the detected objects, providing instance segmentation.
|
||||
*
|
||||
* Here are the models that can be used for instance segmentation.
|
||||
*
|
||||
* mask_rcnn_inception_resnet_v2_atrous_coco 771 36 Masks
|
||||
* mask_rcnn_inception_v2_coco 79 25 Masks
|
||||
* mask_rcnn_resnet101_atrous_coco 470 33 Masks
|
||||
* mask_rcnn_resnet50_atrous_coco 343 29 Masks
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ExampleInstanceSegmentation {
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
|
||||
ResourceLoader resourceLoader = new DefaultResourceLoader();
|
||||
|
||||
// You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||||
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file name>
|
||||
// For performance reasons you may consider downloading the model locally and use the file:/<path to my model> URI instead!
|
||||
String model = "http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
|
||||
|
||||
// All labels for the pre-trained models are available at:
|
||||
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
|
||||
// Use the labels applicable for the model.
|
||||
// Also, for performance reasons you may consider to download the labels and load them from file: instead.
|
||||
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
|
||||
|
||||
// You can cache the TF model on the local file system to improve the bootstrap performance on consecutive runs!
|
||||
boolean CACHE_TF_MODEL = true;
|
||||
|
||||
// For the pre-trained models fromMemory mask you can set the INSTANCE_SEGMENTATION to enable object instance segmentation as well
|
||||
boolean INSTANCE_SEGMENTATION = true;
|
||||
|
||||
// Only object fromMemory confidence above the threshold are returned
|
||||
float CONFIDENCE_THRESHOLD = 0.4f;
|
||||
|
||||
ObjectDetectionService detectionService =
|
||||
new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD, INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
|
||||
|
||||
// You can use file:, http: or classpath: to provide the path to the input image.
|
||||
byte[] image = GraphicsUtils.loadAsByteArray("classpath:/images/object-detection.jpg");
|
||||
|
||||
// Returns a list ObjectDetection domain classes to allow programmatic accesses to the detected objects's metadata
|
||||
List<ObjectDetection> detectedObjects = detectionService.detect(image);
|
||||
|
||||
// Get JSON representation of the detected objects
|
||||
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
|
||||
System.out.println(jsonObjectDetections);
|
||||
|
||||
// Draw the detected object metadata on top of the original image and store the result
|
||||
byte[] annotatedImage = new ObjectDetectionImageAugmenter(INSTANCE_SEGMENTATION).apply(image, detectedObjects);
|
||||
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-segmentation-augmented.jpg"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection.examples;
|
||||
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
|
||||
import org.springframework.cloud.fn.common.tensorflow.deprecated.JsonMapperFunction;
|
||||
import org.springframework.cloud.fn.object.detection.ObjectDetectionImageAugmenter;
|
||||
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ExampleObjectDetection {
|
||||
|
||||
public static void main(String[] args) throws IOException {
|
||||
|
||||
// You can download pre-trained models directly from the zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||||
// Just use the notation <zoo model tar.gz url>#<name of the frozen model file name>
|
||||
// For performance reasons you may consider downloading the model locally and use the file:/<path to my model> URI instead!
|
||||
String model = "http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2018_01_28.tar.gz#frozen_inference_graph.pb";
|
||||
//Resource model = resourceLoader.getResource("http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
|
||||
//Resource model = resourceLoader.getResource("http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_fgvc_2018_07_19.tar.gz#frozen_inference_graph.pb");
|
||||
|
||||
// All labels for the pre-trained models are available at:
|
||||
// https://github.com/tensorflow/models/tree/master/research/object_detection/data
|
||||
// Use the labels applicable for the model.
|
||||
// Also, for performance reasons you may consider to download the labels and load them from file: instead.
|
||||
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
|
||||
//Resource labels = resourceLoader.getResource("https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/fgvc_2854_classes_label_map.pbtxt");
|
||||
|
||||
// You can cache the TF model on the local file system to improve the bootstrap performance on consecutive runs!
|
||||
boolean CACHE_TF_MODEL = true;
|
||||
|
||||
// For the pre-trained models fromMemory mask you can set the INSTANCE_SEGMENTATION to enable object instance segmentation as well
|
||||
boolean NO_INSTANCE_SEGMENTATION = false;
|
||||
|
||||
// Only object fromMemory confidence above the threshold are returned
|
||||
float CONFIDENCE_THRESHOLD = 0.4f;
|
||||
|
||||
ObjectDetectionService detectionService =
|
||||
new ObjectDetectionService(model, labels, CONFIDENCE_THRESHOLD, NO_INSTANCE_SEGMENTATION, CACHE_TF_MODEL);
|
||||
|
||||
// You can use file:, http: or classpath: to provide the path to the input image.
|
||||
String inputImageUri = "classpath:/images/object-detection.jpg";
|
||||
try (InputStream is = new DefaultResourceLoader().getResource(inputImageUri).getInputStream()) {
|
||||
|
||||
byte[] image = StreamUtils.copyToByteArray(is);
|
||||
|
||||
// Returns a list ObjectDetection domain classes to allow programmatic accesses to the detected objects's metadata
|
||||
List<ObjectDetection> detectedObjects = detectionService.detect(image);
|
||||
|
||||
// Get JSON representation of the detected objects
|
||||
String jsonObjectDetections = new JsonMapperFunction().apply(detectedObjects);
|
||||
System.out.println(jsonObjectDetections);
|
||||
|
||||
// Draw the detected object metadata on top of the original image and store the result
|
||||
byte[] annotatedImage = new ObjectDetectionImageAugmenter(NO_INSTANCE_SEGMENTATION).apply(image, detectedObjects);
|
||||
IOUtils.write(annotatedImage, new FileOutputStream("./object-detection-function/target/object-detection-augmented.jpg"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright 2020-2020 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.cloud.fn.object.detection.examples;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.cloud.fn.object.detection.ObjectDetectionService;
|
||||
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class SimpleExample {
|
||||
|
||||
public static void main(String[] args) {
|
||||
// Select a pre-trained model from the model zoo: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
|
||||
// Just use the notation <model zoo url>#<name of the frozen model file in the zoo's tar.gz>
|
||||
String model = "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz#frozen_inference_graph.pb";
|
||||
|
||||
// All labels for the pre-trained models are available at: https://github.com/tensorflow/models/tree/master/research/object_detection/data
|
||||
String labels = "https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt";
|
||||
|
||||
ObjectDetectionService detectionService = new ObjectDetectionService(model, labels,
|
||||
0.4f, // Only object fromMemory confidence above the threshold are returned. Confidence range is [0, 1].
|
||||
false, // No instance segmentation
|
||||
true); // cache the TF model locally
|
||||
|
||||
// You can use file:, http: or classpath: to provide the path to the input image.
|
||||
List<ObjectDetection> detectedObjects = detectionService.detect("classpath:/images/object-detection.jpg");
|
||||
detectedObjects.stream().map(o -> o.toString()).forEach(System.out::println);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user