GH-2168: Fix task type property name in Vertex AI embedding requests

Fixes: #2168

- Change property name from 'taskType' to 'task_type' in VertexAiEmbeddingUtils to match Google API expectations
- Add integration tests to verify task type behavior matches Google SDK
- Add missing auto truncate option copying in VertexAiTextEmbeddingOptions

Signed-off-by: Soby Chacko <soby.chacko@broadcom.com>
This commit is contained in:
Soby Chacko
2025-05-10 16:19:41 -04:00
parent f6dba1bf08
commit 15eb24cd91
3 changed files with 123 additions and 2 deletions

View File

@@ -140,7 +140,7 @@ public abstract class VertexAiEmbeddingUtils {
Struct.Builder textBuilder = Struct.newBuilder();
textBuilder.putFields("content", valueOf(this.content));
if (StringUtils.hasText(this.taskType)) {
textBuilder.putFields("taskType", valueOf(this.taskType));
textBuilder.putFields("task_type", valueOf(this.taskType));
}
if (StringUtils.hasText(this.title)) {
textBuilder.putFields("title", valueOf(this.title));

View File

@@ -187,6 +187,9 @@ public class VertexAiTextEmbeddingOptions implements EmbeddingOptions {
if (fromOptions.getTaskType() != null) {
this.options.setTaskType(fromOptions.getTaskType());
}
if (fromOptions.getAutoTruncate() != null) {
this.options.setAutoTruncate(fromOptions.getAutoTruncate());
}
if (StringUtils.hasText(fromOptions.getTitle())) {
this.options.setTitle(fromOptions.getTitle());
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,6 +18,13 @@ package org.springframework.ai.vertexai.embedding.text;
import java.util.List;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@@ -30,6 +37,7 @@ import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class)
@@ -65,6 +73,116 @@ class VertexAiTextEmbeddingModelIT {
assertThat(this.embeddingModel.dimensions()).isEqualTo(768);
}
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
@Test
void testTaskTypeProperty() {
	String input = "Test text for embedding";

	// Ask Spring AI for an embedding from text-embedding-005 with the task type
	// explicitly set to RETRIEVAL_DOCUMENT.
	VertexAiTextEmbeddingOptions documentOptions = VertexAiTextEmbeddingOptions.builder()
		.model("text-embedding-005")
		.taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT)
		.build();
	EmbeddingRequest request = new EmbeddingRequest(List.of(input), documentOptions);
	EmbeddingResponse response = this.embeddingModel.call(request);

	// Exactly one input was sent, so exactly one embedding is expected back.
	assertThat(response.getResults()).hasSize(1);
	float[] viaSpringAi = response.getResults().get(0).getOutput();
	assertThat(viaSpringAi).isNotNull();

	// Reference vectors obtained by calling the Google SDK directly for the same
	// text: one with the same task type, one with a different task type.
	// NOTE(review): the original comment described RETRIEVAL_QUERY as "the
	// default"; that cannot be confirmed from this file.
	float[] sdkAsDocument = getEmbeddingUsingGoogleSdk(input, "RETRIEVAL_DOCUMENT");
	float[] sdkAsQuery = getEmbeddingUsingGoogleSdk(input, "RETRIEVAL_QUERY");

	// If the task type reaches the service under the right property name, the
	// Spring AI vector equals the SDK vector produced with RETRIEVAL_DOCUMENT ...
	assertThat(viaSpringAi)
		.as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding")
		.isEqualTo(sdkAsDocument);

	// ... and differs from the SDK vector produced with RETRIEVAL_QUERY.
	assertThat(viaSpringAi)
		.as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding")
		.isNotEqualTo(sdkAsQuery);
}
// Fixing https://github.com/spring-projects/spring-ai/issues/2168
@Test
void testDefaultTaskTypeBehavior() {
	String input = "Test text for default embedding";

	// Only the model is configured here; the task type is deliberately left
	// unset so that whatever default Spring AI applies is what gets sent.
	VertexAiTextEmbeddingOptions modelOnlyOptions = VertexAiTextEmbeddingOptions.builder()
		.model("text-embedding-005")
		.build();
	EmbeddingRequest request = new EmbeddingRequest(List.of(input), modelOnlyOptions);
	EmbeddingResponse response = this.embeddingModel.call(request);

	assertThat(response.getResults()).hasSize(1);
	float[] viaSpringAiDefault = response.getResults().get(0).getOutput();

	// The test expects the default to behave like RETRIEVAL_DOCUMENT (per the
	// original author's reading of the documentation), so the reference vector
	// is requested from the Google SDK with that task type.
	float[] sdkAsDocument = getEmbeddingUsingGoogleSdk(input, "RETRIEVAL_DOCUMENT");

	assertThat(viaSpringAiDefault)
		.as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding")
		.isEqualTo(sdkAsDocument);
}
/**
 * Requests an embedding for {@code text} straight from the Google prediction
 * client, bypassing Spring AI, so tests have a reference vector to compare with.
 * <p>
 * The location and project are read from the {@code VERTEX_AI_GEMINI_LOCATION} and
 * {@code VERTEX_AI_GEMINI_PROJECT_ID} environment variables; the model is fixed to
 * the {@code google} publisher's {@code text-embedding-005}.
 * @param text the content to embed
 * @param taskType the value sent as the instance's {@code task_type} field
 * @return the numbers found under {@code embeddings.values} of the first
 * prediction, narrowed to {@code float}
 * @throws RuntimeException wrapping any exception raised while building the
 * request, calling the service or reading the response
 */
private float[] getEmbeddingUsingGoogleSdk(String text, String taskType) {
	try {
		String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
		String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");

		String serviceAddress = String.format("%s-aiplatform.googleapis.com:443", location);
		PredictionServiceSettings clientSettings = PredictionServiceSettings.newBuilder()
			.setEndpoint(serviceAddress)
			.build();
		EndpointName modelEndpoint = EndpointName.ofProjectLocationPublisherModelName(projectId, location,
				"google", "text-embedding-005");

		try (PredictionServiceClient sdkClient = PredictionServiceClient.create(clientSettings)) {
			// A single instance carrying the text and the snake_case task_type
			// field.
			Struct instance = Struct.newBuilder()
				.putFields("content", Value.newBuilder().setStringValue(text).build())
				.putFields("task_type", Value.newBuilder().setStringValue(taskType).build())
				.build();
			PredictRequest predictRequest = PredictRequest.newBuilder()
				.setEndpoint(modelEndpoint.toString())
				.addInstances(Value.newBuilder().setStructValue(instance).build())
				.build();

			// Only one instance was sent, so only the first prediction is read.
			Value firstPrediction = sdkClient.predict(predictRequest).getPredictionsList().get(0);
			Value vector = firstPrediction.getStructValue()
				.getFieldsOrThrow("embeddings")
				.getStructValue()
				.getFieldsOrThrow("values");

			// Protobuf numbers are doubles; narrow each one to float.
			List<Float> components = vector.getListValue()
				.getValuesList()
				.stream()
				.map(component -> (float) component.getNumberValue())
				.collect(toList());

			float[] embedding = new float[components.size()];
			int index = 0;
			for (Float component : components) {
				embedding[index++] = component;
			}
			return embedding;
		}
	}
	catch (Exception e) {
		throw new RuntimeException("Failed to get embedding from Google SDK", e);
	}
}
@SpringBootConfiguration
static class Config {