GH-2168: Fix task type property name in Vertex AI embedding requests
Fixes: #2168

- Change the property name from 'taskType' to 'task_type' in VertexAiEmbeddingUtils to match Google API expectations
- Add integration tests to verify that task type behavior matches the Google SDK
- Add the missing auto truncate option copying in VertexAiTextEmbeddingOptions

Signed-off-by: Soby Chacko <soby.chacko@broadcom.com>
This commit is contained in:
@@ -140,7 +140,7 @@ public abstract class VertexAiEmbeddingUtils {
|
||||
Struct.Builder textBuilder = Struct.newBuilder();
|
||||
textBuilder.putFields("content", valueOf(this.content));
|
||||
if (StringUtils.hasText(this.taskType)) {
|
||||
textBuilder.putFields("taskType", valueOf(this.taskType));
|
||||
textBuilder.putFields("task_type", valueOf(this.taskType));
|
||||
}
|
||||
if (StringUtils.hasText(this.title)) {
|
||||
textBuilder.putFields("title", valueOf(this.title));
|
||||
|
||||
@@ -187,6 +187,9 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions {
|
||||
if (fromOptions.getTaskType() != null) {
|
||||
this.options.setTaskType(fromOptions.getTaskType());
|
||||
}
|
||||
if (fromOptions.getAutoTruncate() != null) {
|
||||
this.options.setAutoTruncate(fromOptions.getAutoTruncate());
|
||||
}
|
||||
if (StringUtils.hasText(fromOptions.getTitle())) {
|
||||
this.options.setTitle(fromOptions.getTitle());
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -18,6 +18,13 @@ package org.springframework.ai.vertexai.embedding.text;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import com.google.cloud.aiplatform.v1.EndpointName;
|
||||
import com.google.cloud.aiplatform.v1.PredictRequest;
|
||||
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
|
||||
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
|
||||
import com.google.protobuf.Struct;
|
||||
import com.google.protobuf.Value;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
@@ -30,6 +37,7 @@ import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class)
|
||||
@@ -65,6 +73,116 @@ class VertexAiTextEmbeddingModelIT {
|
||||
assertThat(this.embeddingModel.dimensions()).isEqualTo(768);
|
||||
}
|
||||
|
||||
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
|
||||
@Test
|
||||
void testTaskTypeProperty() {
|
||||
// Use text-embedding-005 model
|
||||
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
|
||||
.model("text-embedding-005")
|
||||
.taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT)
|
||||
.build();
|
||||
|
||||
String text = "Test text for embedding";
|
||||
|
||||
// Generate embedding using Spring AI with RETRIEVAL_DOCUMENT task type
|
||||
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));
|
||||
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotNull();
|
||||
|
||||
// Get the embedding result
|
||||
float[] springAiEmbedding = embeddingResponse.getResults().get(0).getOutput();
|
||||
|
||||
// Now generate the same embedding using Google SDK directly with
|
||||
// RETRIEVAL_DOCUMENT
|
||||
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");
|
||||
|
||||
// Also generate embedding using Google SDK with RETRIEVAL_QUERY (which is the
|
||||
// default)
|
||||
float[] googleSdkQueryEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_QUERY");
|
||||
|
||||
// Spring AI embedding should match with what gets generated by Google SDK with
|
||||
// RETRIEVAL_DOCUMENT task type.
|
||||
assertThat(springAiEmbedding)
|
||||
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding")
|
||||
.isEqualTo(googleSdkDocumentEmbedding);
|
||||
|
||||
// Spring AI embedding which uses RETRIEVAL_DOCUMENT task_type should not match
|
||||
// with what gets generated by
|
||||
// Google SDK with RETRIEVAL_QUERY task type.
|
||||
assertThat(springAiEmbedding)
|
||||
.as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding")
|
||||
.isNotEqualTo(googleSdkQueryEmbedding);
|
||||
}
|
||||
|
||||
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
|
||||
@Test
|
||||
void testDefaultTaskTypeBehavior() {
|
||||
// Test default behavior without explicitly setting task type
|
||||
VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
|
||||
.model("text-embedding-005")
|
||||
.build();
|
||||
|
||||
String text = "Test text for default embedding";
|
||||
|
||||
EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options));
|
||||
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
|
||||
float[] springAiDefaultEmbedding = embeddingResponse.getResults().get(0).getOutput();
|
||||
|
||||
// According to documentation, default should be RETRIEVAL_DOCUMENT
|
||||
float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT");
|
||||
|
||||
assertThat(springAiDefaultEmbedding)
|
||||
.as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding")
|
||||
.isEqualTo(googleSdkDocumentEmbedding);
|
||||
}
|
||||
|
||||
private float[] getEmbeddingUsingGoogleSdk(String text, String taskType) {
|
||||
try {
|
||||
String endpoint = String.format("%s-aiplatform.googleapis.com:443",
|
||||
System.getenv("VERTEX_AI_GEMINI_LOCATION"));
|
||||
String project = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
|
||||
|
||||
PredictionServiceSettings settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
|
||||
|
||||
EndpointName endpointName = EndpointName.ofProjectLocationPublisherModelName(project,
|
||||
System.getenv("VERTEX_AI_GEMINI_LOCATION"), "google", "text-embedding-005");
|
||||
|
||||
try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
|
||||
PredictRequest.Builder request = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
|
||||
|
||||
request.addInstances(Value.newBuilder()
|
||||
.setStructValue(Struct.newBuilder()
|
||||
.putFields("content", Value.newBuilder().setStringValue(text).build())
|
||||
.putFields("task_type", Value.newBuilder().setStringValue(taskType).build())
|
||||
.build())
|
||||
.build());
|
||||
|
||||
var prediction = client.predict(request.build()).getPredictionsList().get(0);
|
||||
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
|
||||
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
|
||||
|
||||
List<Float> floatList = values.getListValue()
|
||||
.getValuesList()
|
||||
.stream()
|
||||
.map(Value::getNumberValue)
|
||||
.map(Double::floatValue)
|
||||
.collect(toList());
|
||||
|
||||
float[] floatArray = new float[floatList.size()];
|
||||
for (int i = 0; i < floatList.size(); i++) {
|
||||
floatArray[i] = floatList.get(i);
|
||||
}
|
||||
return floatArray;
|
||||
}
|
||||
}
|
||||
catch (Exception e) {
|
||||
throw new RuntimeException("Failed to get embedding from Google SDK", e);
|
||||
}
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
static class Config {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user