Add local, Transformers EmbeddingClient

- EmbeddingClient implementation that computes, locally, sentence embeddings with SBERT transformers.
 - Uses pre-trained transformer models, serialized into Open Neural Network Exchange (ONNX) format.
 - Deep Java Library and the Microsoft ONNX Java Runtime are used to run
   the ONNX models and compute the embeddings efficiently.
 - Add default tokenizer.json and model.onnx for sentence-transformers/all-MiniLM-L6-v2.
 - Add, configurable resource caching service to allow caching
   remote (http/https) resources to the local FS.
 - README.md provides information on how to serialize ONNX models.
 - add Git LFS configuration for large onnx model files.
This commit is contained in:
Christian Tzolov
2023-10-24 09:25:18 +02:00
committed by Christian Tzolov
parent e68bdeb9a0
commit 6030cda598
13 changed files with 31652 additions and 0 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
*.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,71 @@
# Local Transformers Embedding Client
The `TransformersEmbeddingClient` is a `EmbeddingClient` implementation that computes, locally, [sentence embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers) using a selected [sentence transformer](https://www.sbert.net/).
It uses [pre-trained](https://www.sbert.net/docs/pretrained_models.html) transformer models, serialized into the [Open Neural Network Exchange (ONNX)](https://onnx.ai/) format.
The [Deep Java Library](https://djl.ai/) and the Microsoft [ONNX Java Runtime](https://onnxruntime.ai/docs/get-started/with-java.html) libraries are applied to run the ONNX models and compute the embeddings in Java.
## Serialize the Tokenizer and the Transformer Model
To run things in Java, we need to serialize the Tokenizer and the Transformer Model into ONNX format.
### Serialize with optimum-cli
One, quick, way to achieve this, is to use the [optimum-cli](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command line tool.
Following snippet creates an python virtual environment, installs the required packages and runs the optimum-cli to serialize (e.g. export) the models:
```bash
python3 -m venv venv
source ./venv/bin/activate
(venv) pip install --upgrade pip
(venv) pip install optimum onnx onnxruntime
(venv) optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder
```
The `optimum-cli` command exports the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) transformer into the `onnx-output-folder` folder. Later includes the `tokenizer.json` and `model.onnx` files used by the embedding client.
## Apply the ONNX model
Use the `setTokenizerResource(tokenizerJsonUri)` and `setModelResource(modelOnnxUri)` methods to set the URI locations of the exported `tokenizer.json` and `model.onnx` files.
The `classpath:`, `file:` or `https:` URI schemas are supported.
If no other model is explicitly set, the `TransformersEmbeddingClient` defaults to [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model:
| | |
| -------- | ------- |
| Dimensions |384 |
| Avg. performance | 58.80 |
| Speed | 14200 sentences/sec |
| Size | 80MB |
Following snippet illustrates how to use the `TransformersEmbeddingClient`:
```java
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json
embeddingClient.setTokenizerResource("classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json");
// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/model.onnx
embeddingClient.setModelResource("classpath:/onnx/all-MiniLM-L6-v2/model.onnx");
// (optional) defaults to ${java.io.tmpdir}/spring-ai-onnx-model
// Only the http/https resources are cached by default.
embeddingClient.setResourceCacheDirectory("/tmp/onnx-zoo");
embeddingClient.afterPropertiesSet();
List<List<Double>> embeddings =
embeddingClient.embed(List.of("Hello world", "World is big"));
```

View File

@@ -0,0 +1,86 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.experimental.ai</groupId>
<artifactId>spring-ai</artifactId>
<version>0.7.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>transformers-embedding</artifactId>
<packaging>jar</packaging>
<name>Spring AI Embedding Client - Sentence Transormers Embeddings </name>
<description>Spring AI Sentence Transformers Embedding Client</description>
<url>https://github.com/spring-projects-experimental/spring-ai</url>
<scm>
<url>https://github.com/spring-projects-experimental/spring-ai</url>
<connection>git://github.com/spring-projects-experimental/spring-ai.git</connection>
<developerConnection>git@github.com:spring-projects-experimental/spring-ai.git</developerConnection>
</scm>
<properties>
<djl.version>0.24.0</djl.version>
<onnxruntime.version>1.16.1</onnxruntime.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.experimental.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>${parent.version}</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>${onnxruntime.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>${djl.version}</version>
</dependency>
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>${djl.version}</version>
</dependency>
<!-- TESTING -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,149 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;
import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.FileUrlResource;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
/**
* Service that helps caching remote {@link Resource}s on the local file system.
*
* @author Christian Tzolov
*/
public class ResourceCacheService {
private static final Log logger = LogFactory.getLog(ResourceCacheService.class);
/**
* The parent folder that contains all cached resources.
*/
private final File cacheDirectory;
/**
* Resources with URI schemas belonging to the excludedUriSchemas are not cached. By
* default the file and classpath resources are not cached as they are already in the
* local file system.
*/
private List<String> excludedUriSchemas = new ArrayList<>(List.of("file", "classpath"));
public ResourceCacheService() {
this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath());
}
public ResourceCacheService(String rootCacheDirectory) {
this(new File(rootCacheDirectory));
}
public ResourceCacheService(File rootCacheDirectory) {
Assert.notNull(rootCacheDirectory, "Cache directory can not be null.");
this.cacheDirectory = rootCacheDirectory;
if (!this.cacheDirectory.exists()) {
logger.info("Create cache root directory: " + this.cacheDirectory.getAbsolutePath());
this.cacheDirectory.mkdirs();
}
Assert.isTrue(this.cacheDirectory.isDirectory(), "The cache folder must be a directory");
}
/**
* Overrides the excluded URI schemas list.
* @param excludedUriSchemas new list of URI schemas to be excluded from caching.
*/
public void setExcludedUriSchemas(List<String> excludedUriSchemas) {
Assert.notNull(excludedUriSchemas, "The excluded URI schemas list can not be null");
this.excludedUriSchemas = excludedUriSchemas;
}
/**
* Get {@link Resource} representing the cached copy of the original resource.
* @param originalResourceUri Resource to be cached.
* @return Returns a cached resource. If the original resource's URI schema is within
* the excluded schema list the original resource is returned.
*/
public Resource getCachedResource(String originalResourceUri) {
return this.getCachedResource(new DefaultResourceLoader().getResource(originalResourceUri));
}
/**
* Get {@link Resource} representing the cached copy of the original resource.
* @param originalResource Resource to be cached.
* @return Returns a cached resource. If the original resource's URI schema is within
* the excluded schema list the original resource is returned.
*/
public Resource getCachedResource(Resource originalResource) {
try {
if (this.excludedUriSchemas.contains(originalResource.getURI().getScheme())) {
logger.info("The " + originalResource.toString() + " resource with URI schema ["
+ originalResource.getURI().getScheme() + "] is excluded from caching");
return originalResource;
}
File cachedFile = getCachedFile(originalResource);
if (!cachedFile.exists()) {
FileCopyUtils.copy(StreamUtils.copyToByteArray(originalResource.getInputStream()), cachedFile);
logger.info("Caching the " + originalResource.toString() + " resource to: " + cachedFile);
}
return new FileUrlResource(cachedFile.getAbsolutePath());
}
catch (Exception e) {
throw new IllegalStateException("Failed to cache the resource: " + originalResource.getDescription(), e);
}
}
private File getCachedFile(Resource originalResource) throws IOException {
var resourceParentFolder = new File(this.cacheDirectory,
UUID.nameUUIDFromBytes(pathWithoutLastSegment(originalResource.getURI())).toString());
resourceParentFolder.mkdirs();
String newFileName = getCacheName(originalResource);
return new File(resourceParentFolder, newFileName);
}
private byte[] pathWithoutLastSegment(URI uri) {
String path = uri.toASCIIString();
var pathBeforeLastSegment = path.substring(0, path.lastIndexOf('/') + 1);
return pathBeforeLastSegment.getBytes();
}
private String getCacheName(Resource originalResource) throws IOException {
String fileName = originalResource.getFilename();
String fragment = originalResource.getURI().getFragment();
return !StringUtils.hasText(fragment) ? fileName : fileName + "_" + fragment;
}
public void deleteCacheFolder() {
if (this.cacheDirectory.exists()) {
logger.info("Empty Model Cache at:" + this.cacheDirectory.getAbsolutePath());
this.cacheDirectory.delete();
this.cacheDirectory.mkdirs();
}
}
}

View File

@@ -0,0 +1,274 @@
package org.springframework.ai.embedding;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
*
* @author Christian Tzolov
*/
public class TransformersEmbeddingClient implements EmbeddingClient, InitializingBean {
// ONNX tokenizer for the all-MiniLM-L6-v2 model
private final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/embedding-clients/transformers-embedding/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
// ONNX model for all-MiniLM-L6-v2 pre-trained transformer:
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
private final static String DEFAULT_ONNX_MODEL_URI = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/embedding-clients/transformers-embedding/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx";
private final static int EMBEDDING_AXIS = 1;
private Resource tokenizerResource = toResource(DEFAULT_ONNX_TOKENIZER_URI);
private Resource modelResource = toResource(DEFAULT_ONNX_MODEL_URI);
private int gpuDeviceId = -1;
/**
* DJL, Huggingface tokenizer implementation of the {@link Tokenizer} interface that
* converts sentences into token.
*/
private HuggingFaceTokenizer tokenizer;
/**
* ONNX runtime configurations: https://onnxruntime.ai/docs/get-started/with-java.html
*/
private OrtEnvironment environment;
private OrtSession session;
private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
private final MetadataMode metadataMode;
/**
* Resource cache directory. Used to cache remote resources, such as the ONNX models,
* to the local file system.
*/
private String resourceCacheDirectory;
/**
* Allow disabling the resource caching.
*/
private boolean disableCaching = false;
private ResourceCacheService cache;
public TransformersEmbeddingClient() {
this(MetadataMode.NONE);
}
public TransformersEmbeddingClient(MetadataMode metadataMode) {
Assert.notNull(metadataMode, "Metadata mode should not be null");
this.metadataMode = metadataMode;
}
public void setDisableCaching(boolean disableCaching) {
this.disableCaching = disableCaching;
}
public void setResourceCacheDirectory(String resourceCacheDir) {
this.resourceCacheDirectory = resourceCacheDir;
}
public void setGpuDeviceId(int gpuDeviceId) {
this.gpuDeviceId = gpuDeviceId;
}
public void setTokenizerResource(Resource tokenizerResource) {
this.tokenizerResource = tokenizerResource;
}
public void setModelResource(Resource modelResource) {
this.modelResource = modelResource;
}
public void setTokenizerResource(String tokenizerResourceUri) {
this.tokenizerResource = toResource(tokenizerResourceUri);
}
public void setModelResource(String modelResourceUri) {
this.modelResource = toResource(modelResourceUri);
}
public void setEmbeddingDimensions(int dimension) {
this.embeddingDimensions.set(dimension);
}
@Override
public void afterPropertiesSet() throws Exception {
this.cache = StringUtils.hasText(this.resourceCacheDirectory)
? new ResourceCacheService(this.resourceCacheDirectory) : new ResourceCacheService();
this.tokenizer = HuggingFaceTokenizer.newInstance(getCachedResource(this.tokenizerResource).getInputStream(),
Map.of());
this.environment = OrtEnvironment.getEnvironment();
var sessionOptions = new OrtSession.SessionOptions();
if (this.gpuDeviceId >= 0) {
// Run on a GPU or with another provider
sessionOptions.addCUDA(this.gpuDeviceId);
}
this.session = this.environment.createSession(getCachedResource(this.modelResource).getContentAsByteArray(),
sessionOptions);
}
private Resource getCachedResource(Resource resource) {
return this.disableCaching ? resource : this.cache.getCachedResource(resource);
}
@Override
public List<Double> embed(String text) {
return embed(List.of(text)).get(0);
}
@Override
public List<Double> embed(Document document) {
return this.embed(document.getFormattedContent(this.metadataMode));
}
@Override
public EmbeddingResponse embedForResponse(List<String> texts) {
List<Embedding> data = new ArrayList<>();
List<List<Double>> embed = this.embed(texts);
for (int i = 0; i < embed.size(); i++) {
data.add(new Embedding(embed.get(i), i));
}
return new EmbeddingResponse(data, Map.of());
}
@Override
public List<List<Double>> embed(List<String> texts) {
List<List<Double>> resultEmbeddings = new ArrayList<>();
try {
Encoding[] encodings = this.tokenizer.batchEncode(texts);
long[][] input_ids0 = new long[encodings.length][];
long[][] attention_mask0 = new long[encodings.length][];
long[][] token_type_ids0 = new long[encodings.length][];
for (int i = 0; i < encodings.length; i++) {
input_ids0[i] = encodings[i].getIds();
attention_mask0[i] = encodings[i].getAttentionMask();
token_type_ids0[i] = encodings[i].getTypeIds();
}
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
"token_type_ids", tokenTypeIds);
try (OrtSession.Result results = this.session.run(modelInputs)) {
// OnnxValue lastHiddenState = results.get(0);
OnnxValue lastHiddenState = results.get("last_hidden_state").get();
// 0 - batch_size (1..x)
// 1 - sequence_length (128)
// 2 - embedding dimensions (384)
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
for (int i = 0; i < embedding.size(0); i++) {
resultEmbeddings.add(toDoubleList(embedding.get(i).toFloatArray()));
}
}
}
}
catch (OrtException ex) {
throw new RuntimeException(ex);
}
return resultEmbeddings;
}
// Build a NDArray from 3D float array.
private NDArray create(float[][][] data3d, NDManager manager) {
FloatBuffer buffer = FloatBuffer.allocate(data3d.length * data3d[0].length * data3d[0][0].length);
for (float[][] data2d : data3d) {
for (float[] data1d : data2d) {
buffer.put(data1d);
}
}
buffer.rewind();
return manager.create(buffer, new Shape(data3d.length, data3d[0].length, data3d[0][0].length));
}
private NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) {
NDArray attentionMaskExpanded = attentionMask.expandDims(-1)
.broadcast(tokenEmbeddings.getShape())
.toType(DataType.FLOAT32, false);
// Multiply token embeddings with expanded attention mask
NDArray weightedEmbeddings = tokenEmbeddings.mul(attentionMaskExpanded);
// Sum along the appropriate axis
NDArray sumEmbeddings = weightedEmbeddings.sum(new int[] { EMBEDDING_AXIS });
// Clamp the attention mask sum to avoid division by zero
NDArray sumMask = attentionMaskExpanded.sum(new int[] { EMBEDDING_AXIS }).clip(1e-9f, Float.MAX_VALUE);
// Divide sum embeddings by sum mask
return sumEmbeddings.div(sumMask);
}
private List<Double> toDoubleList(float[] floats) {
List<Double> result = new ArrayList<>();
if (floats != null && floats.length > 0) {
for (int i = 0; i < floats.length; i++) {
result.add((double) floats[i]);
}
}
return result;
}
@Override
public int dimensions() {
if (this.embeddingDimensions.get() < 0) {
this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, "Test"));
}
return this.embeddingDimensions.get();
}
private static Resource toResource(String uri) {
return new DefaultResourceLoader().getResource(uri);
}
}

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e3dde332c13808c718680e7bf74a574e7e5d06f55bd6e1527e51509dcb8206f3
size 90387630

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.4 MiB

View File

@@ -0,0 +1,112 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.springframework.core.io.DefaultResourceLoader;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
*/
public class ResourceCacheServiceTests {
@TempDir
File tempDir;
@Test
public void fileResourcesAreExcludedByDefault() throws IOException {
var cache = new ResourceCacheService(tempDir);
var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
var cachedResource = cache.getCachedResource(originalResourceUri);
assertThat(cachedResource).isEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(0);
}
@Test
public void cacheFileResources() throws IOException {
var cache = new ResourceCacheService(tempDir);
cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names,
// including 'file'.
var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
var cachedResource1 = cache.getCachedResource(originalResourceUri);
assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
// Attempt to cache the same resource again should return the already cached
// resource.
var cachedResource2 = cache.getCachedResource(originalResourceUri);
assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
assertThat(cachedResource2).isEqualTo(cachedResource1);
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
}
@Test
public void cacheFileResourcesFromSameParentFolder() throws IOException {
var cache = new ResourceCacheService(tempDir);
cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names,
// including 'file'.
var originalResourceUri1 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
var cachedResource1 = cache.getCachedResource(originalResourceUri1);
// Attempt to cache the same resource again should return the already cached
// resource.
var originalResourceUri2 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/model.png";
var cachedResource2 = cache.getCachedResource(originalResourceUri2);
assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1));
assertThat(cachedResource2).isNotEqualTo(cachedResource1);
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1)
.describedAs(
"As both resources come from the same parent segments they should be cached in a single common parent.");
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(2);
}
@Test
public void cacheHttpResources() throws IOException {
var cache = new ResourceCacheService(tempDir);
var originalResourceUri1 = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties";
var cachedResource1 = cache.getCachedResource(originalResourceUri1);
assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1));
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
}
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.ai.document.Document;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
*/
public class TransformersEmbeddingClientTests {
@Test
void embed() throws Exception {
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
embeddingClient.setResourceCacheDirectory("/tmp/onnx-zoo");
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed("Hello world");
assertThat(embed).hasSize(384);
assertThat(embed.get(0)).isEqualTo(-0.19744634628295898);
assertThat(embed.get(383)).isEqualTo(0.17298996448516846);
}
@Test
void embedDocument() throws Exception {
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
embeddingClient.afterPropertiesSet();
List<Double> embed = embeddingClient.embed(new Document("Hello world"));
assertThat(embed).hasSize(384);
assertThat(embed.get(0)).isEqualTo(-0.19744634628295898);
assertThat(embed.get(383)).isEqualTo(0.17298996448516846);
}
@Test
void embedList() throws Exception {
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
embeddingClient.afterPropertiesSet();
List<List<Double>> embed = embeddingClient.embed(List.of("Hello world", "World is big"));
assertThat(embed).hasSize(2);
assertThat(embed.get(0)).hasSize(384);
assertThat(embed.get(0).get(0)).isEqualTo(-0.19744634628295898);
assertThat(embed.get(0).get(383)).isEqualTo(0.17298996448516846);
assertThat(embed.get(1)).hasSize(384);
assertThat(embed.get(1).get(0)).isEqualTo(0.4293745160102844);
assertThat(embed.get(1).get(383)).isEqualTo(0.05501303821802139);
assertThat(embed.get(0)).isNotEqualTo(embed.get(1));
}
@Test
void embedForResponse() throws Exception {
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
embeddingClient.afterPropertiesSet();
EmbeddingResponse embed = embeddingClient.embedForResponse(List.of("Hello world", "World is big"));
assertThat(embed.getData()).hasSize(2);
assertThat(embed.getMetadata()).isEmpty();
assertThat(embed.getData().get(0).getEmbedding()).hasSize(384);
assertThat(embed.getData().get(0).getEmbedding().get(0)).isEqualTo(-0.19744634628295898);
assertThat(embed.getData().get(0).getEmbedding().get(383)).isEqualTo(0.17298996448516846);
assertThat(embed.getData().get(1).getEmbedding()).hasSize(384);
assertThat(embed.getData().get(1).getEmbedding().get(0)).isEqualTo(0.4293745160102844);
assertThat(embed.getData().get(1).getEmbedding().get(383)).isEqualTo(0.05501303821802139);
}
@Test
void dimensions() throws Exception {
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
embeddingClient.afterPropertiesSet();
assertThat(embeddingClient.dimensions()).isEqualTo(384);
// cached
assertThat(embeddingClient.dimensions()).isEqualTo(384);
}
}

View File

@@ -0,0 +1,129 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.embedding.samples;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import org.springframework.core.io.DefaultResourceLoader;
// https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers
public class ONNXSample {
public static NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) {
NDArray attentionMaskExpanded = attentionMask.expandDims(-1)
.broadcast(tokenEmbeddings.getShape())
.toType(DataType.FLOAT32, false);
// Multiply token embeddings with expanded attention mask
NDArray weightedEmbeddings = tokenEmbeddings.mul(attentionMaskExpanded);
// Sum along the appropriate axis
NDArray sumEmbeddings = weightedEmbeddings.sum(new int[] { 1 });
// Clamp the attention mask sum to avoid division by zero
NDArray sumMask = attentionMaskExpanded.sum(new int[] { 1 }).clip(1e-9f, Float.MAX_VALUE);
// Divide sum embeddings by sum mask
return sumEmbeddings.div(sumMask);
}
public static void main(String[] args) throws Exception {
String TOKENIZER_URI = "classpath:/onnx/tokenizer.json";
String MODEL_URI = "classpath:/onnx/model.onnx";
var tokenizerResource = new DefaultResourceLoader().getResource(TOKENIZER_URI);
var modelResource = new DefaultResourceLoader().getResource(MODEL_URI);
String[] sentences = new String[] { "Hello world" };
// https://docs.djl.ai/extensions/tokenizers/index.html
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(tokenizerResource.getInputStream(), Map.of());
Encoding[] encodings = tokenizer.batchEncode(sentences);
long[][] input_ids0 = new long[encodings.length][];
long[][] attention_mask0 = new long[encodings.length][];
long[][] token_type_ids0 = new long[encodings.length][];
for (int i = 0; i < encodings.length; i++) {
input_ids0[i] = encodings[i].getIds();
attention_mask0[i] = encodings[i].getAttentionMask();
token_type_ids0[i] = encodings[i].getTypeIds();
}
// https://onnxruntime.ai/docs/get-started/with-java.html
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession session = environment.createSession(modelResource.getContentAsByteArray());
OnnxTensor inputIds = OnnxTensor.createTensor(environment, input_ids0);
OnnxTensor attentionMask = OnnxTensor.createTensor(environment, attention_mask0);
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(environment, token_type_ids0);
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", inputIds);
inputs.put("attention_mask", attentionMask);
inputs.put("token_type_ids", tokenTypeIds);
try (OrtSession.Result results = session.run(inputs)) {
OnnxValue lastHiddenState = results.get(0);
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
System.out.println(tokenEmbeddings[0][0][0]);
System.out.println(tokenEmbeddings[0][1][0]);
System.out.println(tokenEmbeddings[0][2][0]);
System.out.println(tokenEmbeddings[0][3][0]);
try (NDManager manager = NDManager.newBaseManager()) {
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
NDArray ndAttentionMask = manager.create(attention_mask0);
System.out.println(ndTokenEmbeddings);
var embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
System.out.println(embedding);
}
}
}
public static NDArray create(float[][][] data, NDManager manager) {
FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length * data[0][0].length);
for (float[][] data2 : data) {
for (float[] d : data2) {
buffer.put(d);
}
}
buffer.rewind();
return manager.create(buffer, new Shape(data.length, data[0].length, data[0][0].length));
}
}

View File

@@ -0,0 +1,43 @@
from transformers import AutoTokenizer, AutoModel
import torch
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
attention_mask1 = attention_mask.unsqueeze(-1)
attention_mask2 = attention_mask1.expand(token_embeddings.size())
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask
#Sentences we want sentence embeddings for
# sentences = ['Hello world']
sentences = ['Hello world', 'World is Big']
# 'Sentences are passed as a list of string.',
# 'The quick brown fox jumps over the lazy dog.']
# sentences = ['This framework generates embeddings for each input sentence',
# 'Sentences are passed as a list of string.',
# 'The quick brown fox jumps over the lazy dog.']
#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
#Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)
#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
print(sentence_embeddings)

View File

@@ -27,6 +27,7 @@
<module>embedding-clients/spring-ai-postgresml-embedding-client</module>
<module>document-readers/pdf-reader</module>
<module>document-readers/tika-reader</module>
<module>embedding-clients/transformers-embedding</module>
</modules>
<organization>