Add local, Transformers EmbeddingClient
- EmbeddingClient implementation that computes, locally, sentence embeddings with SBERT transformers. - Uses pre-trained transformer models, serialized into Open Neural Network Exchange (ONNX) format. - Deep Java Library and the Microsoft ONNX Java Runtime are used to run the ONNX models and compute the embeddings efficiently. - Add default tokenizer.json and model.onnx for sentence-transformers/all-MiniLM-L6-v2. - Add, configurable resource caching service to allow caching remote (http/https) resources to the local FS. - README.md provides information on how to serialize ONNX models. - add Git LFS configuration for large onnx model files.
This commit is contained in:
committed by
Christian Tzolov
parent
e68bdeb9a0
commit
6030cda598
1
.gitattributes
vendored
Normal file
1
.gitattributes
vendored
Normal file
@@ -0,0 +1 @@
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
71
embedding-clients/transformers-embedding/README.md
Normal file
71
embedding-clients/transformers-embedding/README.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Local Transformers Embedding Client
|
||||
|
||||
The `TransformersEmbeddingClient` is a `EmbeddingClient` implementation that computes, locally, [sentence embeddings](https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers) using a selected [sentence transformer](https://www.sbert.net/).
|
||||
|
||||
It uses [pre-trained](https://www.sbert.net/docs/pretrained_models.html) transformer models, serialized into the [Open Neural Network Exchange (ONNX)](https://onnx.ai/) format.
|
||||
|
||||
The [Deep Java Library](https://djl.ai/) and the Microsoft [ONNX Java Runtime](https://onnxruntime.ai/docs/get-started/with-java.html) libraries are applied to run the ONNX models and compute the embeddings in Java.
|
||||
|
||||
## Serialize the Tokenizer and the Transformer Model
|
||||
|
||||
To run things in Java, we need to serialize the Tokenizer and the Transformer Model into ONNX format.
|
||||
|
||||
### Serialize with optimum-cli
|
||||
|
||||
One, quick, way to achieve this, is to use the [optimum-cli](https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command line tool.
|
||||
|
||||
Following snippet creates an python virtual environment, installs the required packages and runs the optimum-cli to serialize (e.g. export) the models:
|
||||
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source ./venv/bin/activate
|
||||
(venv) pip install --upgrade pip
|
||||
(venv) pip install optimum onnx onnxruntime
|
||||
(venv) optimum-cli export onnx --model sentence-transformers/all-MiniLM-L6-v2 onnx-output-folder
|
||||
```
|
||||
|
||||
The `optimum-cli` command exports the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) transformer into the `onnx-output-folder` folder. Later includes the `tokenizer.json` and `model.onnx` files used by the embedding client.
|
||||
|
||||
## Apply the ONNX model
|
||||
|
||||
Use the `setTokenizerResource(tokenizerJsonUri)` and `setModelResource(modelOnnxUri)` methods to set the URI locations of the exported `tokenizer.json` and `model.onnx` files.
|
||||
The `classpath:`, `file:` or `https:` URI schemas are supported.
|
||||
|
||||
If no other model is explicitly set, the `TransformersEmbeddingClient` defaults to [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model:
|
||||
|
||||
| | |
|
||||
| -------- | ------- |
|
||||
| Dimensions |384 |
|
||||
| Avg. performance | 58.80 |
|
||||
| Speed | 14200 sentences/sec |
|
||||
| Size | 80MB |
|
||||
|
||||
|
||||
Following snippet illustrates how to use the `TransformersEmbeddingClient`:
|
||||
|
||||
```java
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
|
||||
// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json
|
||||
embeddingClient.setTokenizerResource("classpath:/onnx/all-MiniLM-L6-v2/tokenizer.json");
|
||||
// (optional) defaults to classpath:/onnx/all-MiniLM-L6-v2/model.onnx
|
||||
embeddingClient.setModelResource("classpath:/onnx/all-MiniLM-L6-v2/model.onnx");
|
||||
|
||||
// (optional) defaults to ${java.io.tmpdir}/spring-ai-onnx-model
|
||||
// Only the http/https resources are cached by default.
|
||||
embeddingClient.setResourceCacheDirectory("/tmp/onnx-zoo");
|
||||
|
||||
embeddingClient.afterPropertiesSet();
|
||||
|
||||
List<List<Double>> embeddings =
|
||||
embeddingClient.embed(List.of("Hello world", "World is big"));
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
86
embedding-clients/transformers-embedding/pom.xml
Normal file
86
embedding-clients/transformers-embedding/pom.xml
Normal file
@@ -0,0 +1,86 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<parent>
|
||||
<groupId>org.springframework.experimental.ai</groupId>
|
||||
<artifactId>spring-ai</artifactId>
|
||||
<version>0.7.0-SNAPSHOT</version>
|
||||
<relativePath>../../pom.xml</relativePath>
|
||||
</parent>
|
||||
<artifactId>transformers-embedding</artifactId>
|
||||
<packaging>jar</packaging>
|
||||
<name>Spring AI Embedding Client - Sentence Transormers Embeddings </name>
|
||||
<description>Spring AI Sentence Transformers Embedding Client</description>
|
||||
<url>https://github.com/spring-projects-experimental/spring-ai</url>
|
||||
|
||||
<scm>
|
||||
<url>https://github.com/spring-projects-experimental/spring-ai</url>
|
||||
<connection>git://github.com/spring-projects-experimental/spring-ai.git</connection>
|
||||
<developerConnection>git@github.com:spring-projects-experimental/spring-ai.git</developerConnection>
|
||||
</scm>
|
||||
|
||||
<properties>
|
||||
<djl.version>0.24.0</djl.version>
|
||||
<onnxruntime.version>1.16.1</onnxruntime.version>
|
||||
</properties>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework.experimental.ai</groupId>
|
||||
<artifactId>spring-ai-core</artifactId>
|
||||
<version>${parent.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.microsoft.onnxruntime</groupId>
|
||||
<artifactId>onnxruntime</artifactId>
|
||||
<version>${onnxruntime.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ai.djl.pytorch</groupId>
|
||||
<artifactId>pytorch-engine</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>api</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ai.djl</groupId>
|
||||
<artifactId>model-zoo</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>ai.djl.huggingface</groupId>
|
||||
<artifactId>tokenizers</artifactId>
|
||||
<version>${djl.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<!-- TESTING -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-testcontainers</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.testcontainers</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -0,0 +1,149 @@
|
||||
/*
|
||||
* Copyright 2023-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.embedding;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.springframework.core.io.FileUrlResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.FileCopyUtils;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Service that helps caching remote {@link Resource}s on the local file system.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ResourceCacheService {
|
||||
|
||||
private static final Log logger = LogFactory.getLog(ResourceCacheService.class);
|
||||
|
||||
/**
|
||||
* The parent folder that contains all cached resources.
|
||||
*/
|
||||
private final File cacheDirectory;
|
||||
|
||||
/**
|
||||
* Resources with URI schemas belonging to the excludedUriSchemas are not cached. By
|
||||
* default the file and classpath resources are not cached as they are already in the
|
||||
* local file system.
|
||||
*/
|
||||
private List<String> excludedUriSchemas = new ArrayList<>(List.of("file", "classpath"));
|
||||
|
||||
public ResourceCacheService() {
|
||||
this(new File(System.getProperty("java.io.tmpdir"), "spring-ai-onnx-model").getAbsolutePath());
|
||||
}
|
||||
|
||||
public ResourceCacheService(String rootCacheDirectory) {
|
||||
this(new File(rootCacheDirectory));
|
||||
}
|
||||
|
||||
public ResourceCacheService(File rootCacheDirectory) {
|
||||
Assert.notNull(rootCacheDirectory, "Cache directory can not be null.");
|
||||
this.cacheDirectory = rootCacheDirectory;
|
||||
if (!this.cacheDirectory.exists()) {
|
||||
logger.info("Create cache root directory: " + this.cacheDirectory.getAbsolutePath());
|
||||
this.cacheDirectory.mkdirs();
|
||||
}
|
||||
Assert.isTrue(this.cacheDirectory.isDirectory(), "The cache folder must be a directory");
|
||||
}
|
||||
|
||||
/**
|
||||
* Overrides the excluded URI schemas list.
|
||||
* @param excludedUriSchemas new list of URI schemas to be excluded from caching.
|
||||
*/
|
||||
public void setExcludedUriSchemas(List<String> excludedUriSchemas) {
|
||||
Assert.notNull(excludedUriSchemas, "The excluded URI schemas list can not be null");
|
||||
this.excludedUriSchemas = excludedUriSchemas;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get {@link Resource} representing the cached copy of the original resource.
|
||||
* @param originalResourceUri Resource to be cached.
|
||||
* @return Returns a cached resource. If the original resource's URI schema is within
|
||||
* the excluded schema list the original resource is returned.
|
||||
*/
|
||||
public Resource getCachedResource(String originalResourceUri) {
|
||||
return this.getCachedResource(new DefaultResourceLoader().getResource(originalResourceUri));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get {@link Resource} representing the cached copy of the original resource.
|
||||
* @param originalResource Resource to be cached.
|
||||
* @return Returns a cached resource. If the original resource's URI schema is within
|
||||
* the excluded schema list the original resource is returned.
|
||||
*/
|
||||
public Resource getCachedResource(Resource originalResource) {
|
||||
try {
|
||||
if (this.excludedUriSchemas.contains(originalResource.getURI().getScheme())) {
|
||||
logger.info("The " + originalResource.toString() + " resource with URI schema ["
|
||||
+ originalResource.getURI().getScheme() + "] is excluded from caching");
|
||||
return originalResource;
|
||||
}
|
||||
|
||||
File cachedFile = getCachedFile(originalResource);
|
||||
if (!cachedFile.exists()) {
|
||||
FileCopyUtils.copy(StreamUtils.copyToByteArray(originalResource.getInputStream()), cachedFile);
|
||||
logger.info("Caching the " + originalResource.toString() + " resource to: " + cachedFile);
|
||||
}
|
||||
return new FileUrlResource(cachedFile.getAbsolutePath());
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new IllegalStateException("Failed to cache the resource: " + originalResource.getDescription(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private File getCachedFile(Resource originalResource) throws IOException {
|
||||
var resourceParentFolder = new File(this.cacheDirectory,
|
||||
UUID.nameUUIDFromBytes(pathWithoutLastSegment(originalResource.getURI())).toString());
|
||||
resourceParentFolder.mkdirs();
|
||||
String newFileName = getCacheName(originalResource);
|
||||
return new File(resourceParentFolder, newFileName);
|
||||
}
|
||||
|
||||
private byte[] pathWithoutLastSegment(URI uri) {
|
||||
String path = uri.toASCIIString();
|
||||
var pathBeforeLastSegment = path.substring(0, path.lastIndexOf('/') + 1);
|
||||
return pathBeforeLastSegment.getBytes();
|
||||
}
|
||||
|
||||
private String getCacheName(Resource originalResource) throws IOException {
|
||||
String fileName = originalResource.getFilename();
|
||||
String fragment = originalResource.getURI().getFragment();
|
||||
return !StringUtils.hasText(fragment) ? fileName : fileName + "_" + fragment;
|
||||
}
|
||||
|
||||
public void deleteCacheFolder() {
|
||||
if (this.cacheDirectory.exists()) {
|
||||
logger.info("Empty Model Cache at:" + this.cacheDirectory.getAbsolutePath());
|
||||
this.cacheDirectory.delete();
|
||||
this.cacheDirectory.mkdirs();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
package org.springframework.ai.embedding;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OnnxValue;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* https://www.sbert.net/index.html https://www.sbert.net/docs/pretrained_models.html
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class TransformersEmbeddingClient implements EmbeddingClient, InitializingBean {
|
||||
|
||||
// ONNX tokenizer for the all-MiniLM-L6-v2 model
|
||||
private final static String DEFAULT_ONNX_TOKENIZER_URI = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/embedding-clients/transformers-embedding/src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
|
||||
|
||||
// ONNX model for all-MiniLM-L6-v2 pre-trained transformer:
|
||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
|
||||
private final static String DEFAULT_ONNX_MODEL_URI = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/embedding-clients/transformers-embedding/src/main/resources/onnx/all-MiniLM-L6-v2/model.onnx";
|
||||
|
||||
private final static int EMBEDDING_AXIS = 1;
|
||||
|
||||
private Resource tokenizerResource = toResource(DEFAULT_ONNX_TOKENIZER_URI);
|
||||
|
||||
private Resource modelResource = toResource(DEFAULT_ONNX_MODEL_URI);
|
||||
|
||||
private int gpuDeviceId = -1;
|
||||
|
||||
/**
|
||||
* DJL, Huggingface tokenizer implementation of the {@link Tokenizer} interface that
|
||||
* converts sentences into token.
|
||||
*/
|
||||
private HuggingFaceTokenizer tokenizer;
|
||||
|
||||
/**
|
||||
* ONNX runtime configurations: https://onnxruntime.ai/docs/get-started/with-java.html
|
||||
*/
|
||||
private OrtEnvironment environment;
|
||||
|
||||
private OrtSession session;
|
||||
|
||||
private final AtomicInteger embeddingDimensions = new AtomicInteger(-1);
|
||||
|
||||
private final MetadataMode metadataMode;
|
||||
|
||||
/**
|
||||
* Resource cache directory. Used to cache remote resources, such as the ONNX models,
|
||||
* to the local file system.
|
||||
*/
|
||||
private String resourceCacheDirectory;
|
||||
|
||||
/**
|
||||
* Allow disabling the resource caching.
|
||||
*/
|
||||
private boolean disableCaching = false;
|
||||
|
||||
private ResourceCacheService cache;
|
||||
|
||||
public TransformersEmbeddingClient() {
|
||||
this(MetadataMode.NONE);
|
||||
}
|
||||
|
||||
public TransformersEmbeddingClient(MetadataMode metadataMode) {
|
||||
Assert.notNull(metadataMode, "Metadata mode should not be null");
|
||||
this.metadataMode = metadataMode;
|
||||
}
|
||||
|
||||
public void setDisableCaching(boolean disableCaching) {
|
||||
this.disableCaching = disableCaching;
|
||||
}
|
||||
|
||||
public void setResourceCacheDirectory(String resourceCacheDir) {
|
||||
this.resourceCacheDirectory = resourceCacheDir;
|
||||
}
|
||||
|
||||
public void setGpuDeviceId(int gpuDeviceId) {
|
||||
this.gpuDeviceId = gpuDeviceId;
|
||||
}
|
||||
|
||||
public void setTokenizerResource(Resource tokenizerResource) {
|
||||
this.tokenizerResource = tokenizerResource;
|
||||
}
|
||||
|
||||
public void setModelResource(Resource modelResource) {
|
||||
this.modelResource = modelResource;
|
||||
}
|
||||
|
||||
public void setTokenizerResource(String tokenizerResourceUri) {
|
||||
this.tokenizerResource = toResource(tokenizerResourceUri);
|
||||
}
|
||||
|
||||
public void setModelResource(String modelResourceUri) {
|
||||
this.modelResource = toResource(modelResourceUri);
|
||||
}
|
||||
|
||||
public void setEmbeddingDimensions(int dimension) {
|
||||
this.embeddingDimensions.set(dimension);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() throws Exception {
|
||||
this.cache = StringUtils.hasText(this.resourceCacheDirectory)
|
||||
? new ResourceCacheService(this.resourceCacheDirectory) : new ResourceCacheService();
|
||||
this.tokenizer = HuggingFaceTokenizer.newInstance(getCachedResource(this.tokenizerResource).getInputStream(),
|
||||
Map.of());
|
||||
this.environment = OrtEnvironment.getEnvironment();
|
||||
|
||||
var sessionOptions = new OrtSession.SessionOptions();
|
||||
if (this.gpuDeviceId >= 0) {
|
||||
// Run on a GPU or with another provider
|
||||
sessionOptions.addCUDA(this.gpuDeviceId);
|
||||
}
|
||||
this.session = this.environment.createSession(getCachedResource(this.modelResource).getContentAsByteArray(),
|
||||
sessionOptions);
|
||||
}
|
||||
|
||||
private Resource getCachedResource(Resource resource) {
|
||||
return this.disableCaching ? resource : this.cache.getCachedResource(resource);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> embed(String text) {
|
||||
return embed(List.of(text)).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Double> embed(Document document) {
|
||||
return this.embed(document.getFormattedContent(this.metadataMode));
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingResponse embedForResponse(List<String> texts) {
|
||||
List<Embedding> data = new ArrayList<>();
|
||||
List<List<Double>> embed = this.embed(texts);
|
||||
for (int i = 0; i < embed.size(); i++) {
|
||||
data.add(new Embedding(embed.get(i), i));
|
||||
}
|
||||
return new EmbeddingResponse(data, Map.of());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<List<Double>> embed(List<String> texts) {
|
||||
|
||||
List<List<Double>> resultEmbeddings = new ArrayList<>();
|
||||
|
||||
try {
|
||||
|
||||
Encoding[] encodings = this.tokenizer.batchEncode(texts);
|
||||
|
||||
long[][] input_ids0 = new long[encodings.length][];
|
||||
long[][] attention_mask0 = new long[encodings.length][];
|
||||
long[][] token_type_ids0 = new long[encodings.length][];
|
||||
|
||||
for (int i = 0; i < encodings.length; i++) {
|
||||
input_ids0[i] = encodings[i].getIds();
|
||||
attention_mask0[i] = encodings[i].getAttentionMask();
|
||||
token_type_ids0[i] = encodings[i].getTypeIds();
|
||||
}
|
||||
|
||||
OnnxTensor inputIds = OnnxTensor.createTensor(this.environment, input_ids0);
|
||||
OnnxTensor attentionMask = OnnxTensor.createTensor(this.environment, attention_mask0);
|
||||
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(this.environment, token_type_ids0);
|
||||
|
||||
Map<String, OnnxTensor> modelInputs = Map.of("input_ids", inputIds, "attention_mask", attentionMask,
|
||||
"token_type_ids", tokenTypeIds);
|
||||
|
||||
try (OrtSession.Result results = this.session.run(modelInputs)) {
|
||||
|
||||
// OnnxValue lastHiddenState = results.get(0);
|
||||
OnnxValue lastHiddenState = results.get("last_hidden_state").get();
|
||||
|
||||
// 0 - batch_size (1..x)
|
||||
// 1 - sequence_length (128)
|
||||
// 2 - embedding dimensions (384)
|
||||
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
|
||||
|
||||
try (NDManager manager = NDManager.newBaseManager()) {
|
||||
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
|
||||
NDArray ndAttentionMask = manager.create(attention_mask0);
|
||||
|
||||
NDArray embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
|
||||
|
||||
for (int i = 0; i < embedding.size(0); i++) {
|
||||
resultEmbeddings.add(toDoubleList(embedding.get(i).toFloatArray()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (OrtException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
|
||||
return resultEmbeddings;
|
||||
}
|
||||
|
||||
// Build a NDArray from 3D float array.
|
||||
private NDArray create(float[][][] data3d, NDManager manager) {
|
||||
|
||||
FloatBuffer buffer = FloatBuffer.allocate(data3d.length * data3d[0].length * data3d[0][0].length);
|
||||
|
||||
for (float[][] data2d : data3d) {
|
||||
for (float[] data1d : data2d) {
|
||||
buffer.put(data1d);
|
||||
}
|
||||
}
|
||||
buffer.rewind();
|
||||
|
||||
return manager.create(buffer, new Shape(data3d.length, data3d[0].length, data3d[0][0].length));
|
||||
}
|
||||
|
||||
private NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) {
|
||||
|
||||
NDArray attentionMaskExpanded = attentionMask.expandDims(-1)
|
||||
.broadcast(tokenEmbeddings.getShape())
|
||||
.toType(DataType.FLOAT32, false);
|
||||
|
||||
// Multiply token embeddings with expanded attention mask
|
||||
NDArray weightedEmbeddings = tokenEmbeddings.mul(attentionMaskExpanded);
|
||||
|
||||
// Sum along the appropriate axis
|
||||
NDArray sumEmbeddings = weightedEmbeddings.sum(new int[] { EMBEDDING_AXIS });
|
||||
|
||||
// Clamp the attention mask sum to avoid division by zero
|
||||
NDArray sumMask = attentionMaskExpanded.sum(new int[] { EMBEDDING_AXIS }).clip(1e-9f, Float.MAX_VALUE);
|
||||
|
||||
// Divide sum embeddings by sum mask
|
||||
return sumEmbeddings.div(sumMask);
|
||||
}
|
||||
|
||||
private List<Double> toDoubleList(float[] floats) {
|
||||
List<Double> result = new ArrayList<>();
|
||||
if (floats != null && floats.length > 0) {
|
||||
for (int i = 0; i < floats.length; i++) {
|
||||
result.add((double) floats[i]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimensions() {
|
||||
if (this.embeddingDimensions.get() < 0) {
|
||||
this.embeddingDimensions.set(EmbeddingUtil.dimensions(this, "Test"));
|
||||
}
|
||||
return this.embeddingDimensions.get();
|
||||
}
|
||||
|
||||
private static Resource toResource(String uri) {
|
||||
return new DefaultResourceLoader().getResource(uri);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e3dde332c13808c718680e7bf74a574e7e5d06f55bd6e1527e51509dcb8206f3
|
||||
size 90387630
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.4 MiB |
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright 2023-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.embedding;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class ResourceCacheServiceTests {
|
||||
|
||||
@TempDir
|
||||
File tempDir;
|
||||
|
||||
@Test
|
||||
public void fileResourcesAreExcludedByDefault() throws IOException {
|
||||
var cache = new ResourceCacheService(tempDir);
|
||||
var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
|
||||
var cachedResource = cache.getCachedResource(originalResourceUri);
|
||||
|
||||
assertThat(cachedResource).isEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
|
||||
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void cacheFileResources() throws IOException {
|
||||
var cache = new ResourceCacheService(tempDir);
|
||||
|
||||
cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names,
|
||||
// including 'file'.
|
||||
|
||||
var originalResourceUri = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
|
||||
var cachedResource1 = cache.getCachedResource(originalResourceUri);
|
||||
|
||||
assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
|
||||
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
|
||||
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
|
||||
|
||||
// Attempt to cache the same resource again should return the already cached
|
||||
// resource.
|
||||
var cachedResource2 = cache.getCachedResource(originalResourceUri);
|
||||
|
||||
assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri));
|
||||
assertThat(cachedResource2).isEqualTo(cachedResource1);
|
||||
|
||||
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
|
||||
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void cacheFileResourcesFromSameParentFolder() throws IOException {
|
||||
var cache = new ResourceCacheService(tempDir);
|
||||
|
||||
cache.setExcludedUriSchemas(List.of()); // erase the excluded schema names,
|
||||
// including 'file'.
|
||||
|
||||
var originalResourceUri1 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/tokenizer.json";
|
||||
var cachedResource1 = cache.getCachedResource(originalResourceUri1);
|
||||
|
||||
// Attempt to cache the same resource again should return the already cached
|
||||
// resource.
|
||||
var originalResourceUri2 = "file:src/main/resources/onnx/all-MiniLM-L6-v2/model.png";
|
||||
var cachedResource2 = cache.getCachedResource(originalResourceUri2);
|
||||
|
||||
assertThat(cachedResource2).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1));
|
||||
assertThat(cachedResource2).isNotEqualTo(cachedResource1);
|
||||
|
||||
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1)
|
||||
.describedAs(
|
||||
"As both resources come from the same parent segments they should be cached in a single common parent.");
|
||||
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(2);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void cacheHttpResources() throws IOException {
|
||||
var cache = new ResourceCacheService(tempDir);
|
||||
|
||||
var originalResourceUri1 = "https://raw.githubusercontent.com/spring-projects-experimental/spring-ai/main/spring-ai-core/src/main/resources/embedding/embedding-model-dimensions.properties";
|
||||
var cachedResource1 = cache.getCachedResource(originalResourceUri1);
|
||||
|
||||
assertThat(cachedResource1).isNotEqualTo(new DefaultResourceLoader().getResource(originalResourceUri1));
|
||||
assertThat(Files.list(tempDir.toPath()).count()).isEqualTo(1);
|
||||
assertThat(Files.list(Files.list(tempDir.toPath()).iterator().next()).count()).isEqualTo(1);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
/*
|
||||
* Copyright 2023-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.embedding;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class TransformersEmbeddingClientTests {
|
||||
|
||||
@Test
|
||||
void embed() throws Exception {
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
embeddingClient.setResourceCacheDirectory("/tmp/onnx-zoo");
|
||||
embeddingClient.afterPropertiesSet();
|
||||
List<Double> embed = embeddingClient.embed("Hello world");
|
||||
assertThat(embed).hasSize(384);
|
||||
assertThat(embed.get(0)).isEqualTo(-0.19744634628295898);
|
||||
assertThat(embed.get(383)).isEqualTo(0.17298996448516846);
|
||||
}
|
||||
|
||||
@Test
|
||||
void embedDocument() throws Exception {
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
embeddingClient.afterPropertiesSet();
|
||||
List<Double> embed = embeddingClient.embed(new Document("Hello world"));
|
||||
assertThat(embed).hasSize(384);
|
||||
assertThat(embed.get(0)).isEqualTo(-0.19744634628295898);
|
||||
assertThat(embed.get(383)).isEqualTo(0.17298996448516846);
|
||||
}
|
||||
|
||||
@Test
|
||||
void embedList() throws Exception {
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
embeddingClient.afterPropertiesSet();
|
||||
List<List<Double>> embed = embeddingClient.embed(List.of("Hello world", "World is big"));
|
||||
assertThat(embed).hasSize(2);
|
||||
assertThat(embed.get(0)).hasSize(384);
|
||||
assertThat(embed.get(0).get(0)).isEqualTo(-0.19744634628295898);
|
||||
assertThat(embed.get(0).get(383)).isEqualTo(0.17298996448516846);
|
||||
|
||||
assertThat(embed.get(1)).hasSize(384);
|
||||
assertThat(embed.get(1).get(0)).isEqualTo(0.4293745160102844);
|
||||
assertThat(embed.get(1).get(383)).isEqualTo(0.05501303821802139);
|
||||
|
||||
assertThat(embed.get(0)).isNotEqualTo(embed.get(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void embedForResponse() throws Exception {
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
embeddingClient.afterPropertiesSet();
|
||||
EmbeddingResponse embed = embeddingClient.embedForResponse(List.of("Hello world", "World is big"));
|
||||
assertThat(embed.getData()).hasSize(2);
|
||||
assertThat(embed.getMetadata()).isEmpty();
|
||||
|
||||
assertThat(embed.getData().get(0).getEmbedding()).hasSize(384);
|
||||
assertThat(embed.getData().get(0).getEmbedding().get(0)).isEqualTo(-0.19744634628295898);
|
||||
assertThat(embed.getData().get(0).getEmbedding().get(383)).isEqualTo(0.17298996448516846);
|
||||
|
||||
assertThat(embed.getData().get(1).getEmbedding()).hasSize(384);
|
||||
assertThat(embed.getData().get(1).getEmbedding().get(0)).isEqualTo(0.4293745160102844);
|
||||
assertThat(embed.getData().get(1).getEmbedding().get(383)).isEqualTo(0.05501303821802139);
|
||||
}
|
||||
|
||||
@Test
|
||||
void dimensions() throws Exception {
|
||||
|
||||
TransformersEmbeddingClient embeddingClient = new TransformersEmbeddingClient();
|
||||
embeddingClient.afterPropertiesSet();
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(384);
|
||||
// cached
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(384);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
/*
|
||||
* Copyright 2023-2023 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.embedding.samples;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import ai.djl.huggingface.tokenizers.Encoding;
|
||||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
|
||||
import ai.djl.ndarray.NDArray;
|
||||
import ai.djl.ndarray.NDManager;
|
||||
import ai.djl.ndarray.types.DataType;
|
||||
import ai.djl.ndarray.types.Shape;
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OnnxValue;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
|
||||
import org.springframework.core.io.DefaultResourceLoader;
|
||||
|
||||
// https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers
|
||||
|
||||
public class ONNXSample {
|
||||
|
||||
public static NDArray meanPooling(NDArray tokenEmbeddings, NDArray attentionMask) {
|
||||
|
||||
NDArray attentionMaskExpanded = attentionMask.expandDims(-1)
|
||||
.broadcast(tokenEmbeddings.getShape())
|
||||
.toType(DataType.FLOAT32, false);
|
||||
|
||||
// Multiply token embeddings with expanded attention mask
|
||||
NDArray weightedEmbeddings = tokenEmbeddings.mul(attentionMaskExpanded);
|
||||
|
||||
// Sum along the appropriate axis
|
||||
NDArray sumEmbeddings = weightedEmbeddings.sum(new int[] { 1 });
|
||||
|
||||
// Clamp the attention mask sum to avoid division by zero
|
||||
NDArray sumMask = attentionMaskExpanded.sum(new int[] { 1 }).clip(1e-9f, Float.MAX_VALUE);
|
||||
|
||||
// Divide sum embeddings by sum mask
|
||||
return sumEmbeddings.div(sumMask);
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
String TOKENIZER_URI = "classpath:/onnx/tokenizer.json";
|
||||
String MODEL_URI = "classpath:/onnx/model.onnx";
|
||||
|
||||
var tokenizerResource = new DefaultResourceLoader().getResource(TOKENIZER_URI);
|
||||
var modelResource = new DefaultResourceLoader().getResource(MODEL_URI);
|
||||
|
||||
String[] sentences = new String[] { "Hello world" };
|
||||
|
||||
// https://docs.djl.ai/extensions/tokenizers/index.html
|
||||
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(tokenizerResource.getInputStream(), Map.of());
|
||||
Encoding[] encodings = tokenizer.batchEncode(sentences);
|
||||
|
||||
long[][] input_ids0 = new long[encodings.length][];
|
||||
long[][] attention_mask0 = new long[encodings.length][];
|
||||
long[][] token_type_ids0 = new long[encodings.length][];
|
||||
|
||||
for (int i = 0; i < encodings.length; i++) {
|
||||
input_ids0[i] = encodings[i].getIds();
|
||||
attention_mask0[i] = encodings[i].getAttentionMask();
|
||||
token_type_ids0[i] = encodings[i].getTypeIds();
|
||||
}
|
||||
|
||||
// https://onnxruntime.ai/docs/get-started/with-java.html
|
||||
OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession session = environment.createSession(modelResource.getContentAsByteArray());
|
||||
|
||||
OnnxTensor inputIds = OnnxTensor.createTensor(environment, input_ids0);
|
||||
OnnxTensor attentionMask = OnnxTensor.createTensor(environment, attention_mask0);
|
||||
OnnxTensor tokenTypeIds = OnnxTensor.createTensor(environment, token_type_ids0);
|
||||
|
||||
Map<String, OnnxTensor> inputs = new HashMap<>();
|
||||
inputs.put("input_ids", inputIds);
|
||||
inputs.put("attention_mask", attentionMask);
|
||||
inputs.put("token_type_ids", tokenTypeIds);
|
||||
|
||||
try (OrtSession.Result results = session.run(inputs)) {
|
||||
|
||||
OnnxValue lastHiddenState = results.get(0);
|
||||
|
||||
float[][][] tokenEmbeddings = (float[][][]) lastHiddenState.getValue();
|
||||
|
||||
System.out.println(tokenEmbeddings[0][0][0]);
|
||||
System.out.println(tokenEmbeddings[0][1][0]);
|
||||
System.out.println(tokenEmbeddings[0][2][0]);
|
||||
System.out.println(tokenEmbeddings[0][3][0]);
|
||||
|
||||
try (NDManager manager = NDManager.newBaseManager()) {
|
||||
NDArray ndTokenEmbeddings = create(tokenEmbeddings, manager);
|
||||
NDArray ndAttentionMask = manager.create(attention_mask0);
|
||||
System.out.println(ndTokenEmbeddings);
|
||||
|
||||
var embedding = meanPooling(ndTokenEmbeddings, ndAttentionMask);
|
||||
System.out.println(embedding);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static NDArray create(float[][][] data, NDManager manager) {
|
||||
FloatBuffer buffer = FloatBuffer.allocate(data.length * data[0].length * data[0][0].length);
|
||||
for (float[][] data2 : data) {
|
||||
for (float[] d : data2) {
|
||||
buffer.put(d);
|
||||
}
|
||||
}
|
||||
buffer.rewind();
|
||||
return manager.create(buffer, new Shape(data.length, data[0].length, data[0][0].length));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
|
||||
|
||||
#Mean Pooling - Take attention mask into account for correct averaging
|
||||
def mean_pooling(model_output, attention_mask):
|
||||
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
||||
|
||||
attention_mask1 = attention_mask.unsqueeze(-1)
|
||||
attention_mask2 = attention_mask1.expand(token_embeddings.size())
|
||||
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
||||
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
return sum_embeddings / sum_mask
|
||||
|
||||
|
||||
|
||||
#Sentences we want sentence embeddings for
|
||||
# sentences = ['Hello world']
|
||||
sentences = ['Hello world', 'World is Big']
|
||||
# 'Sentences are passed as a list of string.',
|
||||
# 'The quick brown fox jumps over the lazy dog.']
|
||||
# sentences = ['This framework generates embeddings for each input sentence',
|
||||
# 'Sentences are passed as a list of string.',
|
||||
# 'The quick brown fox jumps over the lazy dog.']
|
||||
|
||||
#Load AutoModel from huggingface model repository
|
||||
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
||||
|
||||
#Tokenize sentences
|
||||
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
|
||||
|
||||
#Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = model(**encoded_input)
|
||||
|
||||
#Perform pooling. In this case, mean pooling
|
||||
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
||||
|
||||
print(sentence_embeddings)
|
||||
|
||||
Reference in New Issue
Block a user