Update to ResponseMetadata design

* Remove inheritance from HashMap
* No more subclasses per model provider
* Builder class for ChatResponse
This commit is contained in:
Mark Pollack
2024-07-15 17:45:08 -04:00
parent 17c44237a5
commit a0b3d12a27
39 changed files with 767 additions and 731 deletions

View File

@@ -15,14 +15,6 @@
*/
package org.springframework.ai.anthropic;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
@@ -32,12 +24,13 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -52,10 +45,17 @@ import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
* The {@link ChatModel} implementation for the Anthropic service.
*
@@ -228,7 +228,20 @@ public class AnthropicChatModel extends AbstractToolCallSupport<ChatCompletionRe
.withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
}).toList();
return new ChatResponse(generations, AnthropicChatResponseMetadata.from(chatCompletion));
return new ChatResponse(generations, from(chatCompletion));
}
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("stop-reason", result.stopReason())
.withKeyValue("stop-sequence", result.stopSequence())
.withKeyValue("type", result.type())
.build();
}
private String fromMediaData(Object mediaData) {

View File

@@ -1,103 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.anthropic.metadata;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal AnthropicApi}.
*
* @author Christian Tzolov
* @author Thomas Vitale
* @see ChatResponseMetadata
* @see RateLimit
* @see Usage
* @since 1.0.0
*/
public class AnthropicChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";
public static AnthropicChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
AnthropicUsage usage = AnthropicUsage.from(result.usage());
return new AnthropicChatResponseMetadata(result.id(), result.model(), usage);
}
private final String id;
private final String model;
@Nullable
private RateLimit rateLimit;
private final Usage usage;
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage) {
this(id, model, usage, null);
}
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage,
@Nullable AnthropicRateLimit rateLimit) {
this.id = id;
this.model = model;
this.usage = usage;
this.rateLimit = rateLimit;
}
@Override
public String getId() {
return this.id;
}
@Override
public String getModel() {
return this.model;
}
@Override
@Nullable
public RateLimit getRateLimit() {
RateLimit rl = this.rateLimit;
return rl != null ? rl : new EmptyRateLimit();
}
@Override
public Usage getUsage() {
Usage usage = this.usage;
return usage != null ? usage : new EmptyUsage();
}
public AnthropicChatResponseMetadata withRateLimit(RateLimit rateLimit) {
this.rateLimit = rateLimit;
return this;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
}
}

View File

@@ -15,33 +15,6 @@
*/
package org.springframework.ai.azure.openai;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatChoice;
import com.azure.ai.openai.models.ChatCompletions;
@@ -68,10 +41,36 @@ import com.azure.ai.openai.models.FunctionCall;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
@@ -151,8 +150,22 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport<ChatCompletion
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
return new ChatResponse(generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
}
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
.withId(id)
.withUsage(usage)
.withModel(chatCompletions.getModel())
.withPromptMetadata(promptFilterMetadata)
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
.build();
return chatResponseMetadata;
}
@Override

View File

@@ -1,83 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.azure.openai.metadata;
import com.azure.ai.openai.models.ChatCompletions;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for
* {@literal Microsoft Azure OpenAI Service}.
*
* @author John Blum
* @author Thomas Vitale
* @see ChatResponseMetadata
* @since 0.7.1
*/
public class AzureOpenAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
@SuppressWarnings("all")
public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions,
PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
String id = chatCompletions.getId();
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage,
promptFilterMetadata);
return chatResponseMetadata;
}
private final String id;
private final Usage usage;
private final PromptMetadata promptMetadata;
protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) {
this.id = id;
this.usage = usage;
this.promptMetadata = promptMetadata;
}
@Override
public String getId() {
return this.id;
}
@Override
public Usage getUsage() {
return this.usage;
}
@Override
public PromptMetadata getPromptMetadata() {
return this.promptMetadata;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
}
}

View File

@@ -2,6 +2,7 @@ package org.springframework.ai.azure.openai.metadata;
import com.azure.ai.openai.models.ImageGenerations;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.MutableResponseMetadata;
import org.springframework.util.Assert;
import java.util.HashMap;
@@ -15,7 +16,7 @@ import java.util.Objects;
* @author Benoit Moussaud
* @since 1.0.0 M1
*/
public class AzureOpenAiImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata {
private final Long created;

View File

@@ -15,14 +15,6 @@
*/
package org.springframework.ai.mistralai;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -31,6 +23,7 @@ import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -44,7 +37,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
import org.springframework.ai.mistralai.metadata.MistralAiChatResponseMetadata;
import org.springframework.ai.mistralai.metadata.MistralAiUsage;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -53,10 +46,17 @@ import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author Ricken Bazolo
* @author Christian Tzolov
@@ -134,10 +134,21 @@ public class MistralAiChatModel extends AbstractToolCallSupport<ChatCompletion>
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();
return new ChatResponse(generations, MistralAiChatResponseMetadata.from(chatCompletion));
return new ChatResponse(generations, from(chatCompletion));
});
}
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
MistralAiUsage usage = MistralAiUsage.from(result.usage());
return ChatResponseMetadata.builder()
.withId(result.id())
.withModel(result.model())
.withUsage(usage)
.withKeyValue("created", result.created())
.build();
}
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
Map<String, Object> map = new HashMap<>();

View File

@@ -1,62 +0,0 @@
package org.springframework.ai.mistralai.metadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal Mistral AI}.
*
* @author Thomas Vitale
* @see ChatResponseMetadata
* @see Usage
* @since 1.0.0
*/
public class MistralAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s }";
public static MistralAiChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
MistralAiUsage usage = MistralAiUsage.from(result.usage());
return new MistralAiChatResponseMetadata(result.id(), result.model(), usage);
}
private final String id;
private final String model;
private final Usage usage;
protected MistralAiChatResponseMetadata(String id, String model, MistralAiUsage usage) {
this.id = id;
this.model = model;
this.usage = usage;
}
@Override
public String getId() {
return this.id;
}
@Override
public String getModel() {
return this.model;
}
@Override
public Usage getUsage() {
Usage usage = this.usage;
return usage != null ? usage : new EmptyUsage();
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage());
}
}

View File

@@ -15,27 +15,27 @@
*/
package org.springframework.ai.ollama;
import java.util.Base64;
import java.util.List;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import java.util.Base64;
import java.util.List;
/**
* {@link ChatModel} implementation for {@literal Ollama}.
@@ -102,7 +102,23 @@ public class OllamaChatModel implements ChatModel {
if (response.promptEvalCount() != null && response.evalCount() != null) {
generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
}
return new ChatResponse(List.of(generator), OllamaChatResponseMetadata.from(response));
return new ChatResponse(List.of(generator), from(response));
}
public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
return ChatResponseMetadata.builder()
.withUsage(OllamaUsage.from(response))
.withModel(response.model())
.withKeyValue("created-at", response.createdAt())
.withKeyValue("eval-duration", response.evalDuration())
.withKeyValue("eval-count", response.evalCount())
.withKeyValue("load-duration", response.loadDuration())
.withKeyValue("eval-duration", response.promptEvalDuration())
.withKeyValue("eval-count", response.promptEvalCount())
.withKeyValue("total-duration", response.totalDuration())
.withKeyValue("done", response.done())
.build();
}
@Override
@@ -116,7 +132,7 @@ public class OllamaChatModel implements ChatModel {
if (Boolean.TRUE.equals(chunk.done())) {
generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
}
return new ChatResponse(List.of(generation), OllamaChatResponseMetadata.from(chunk));
return new ChatResponse(List.of(generation), from(chunk));
});
}

View File

@@ -1,57 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.ollama.metadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal Ollama}
*
* @see ChatResponseMetadata
* @author Fu Cheng
*/
public class OllamaChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, usage: %2$s, rateLimit: %3$s }";
public static OllamaChatResponseMetadata from(OllamaApi.ChatResponse response) {
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
Usage usage = OllamaUsage.from(response);
return new OllamaChatResponseMetadata(usage);
}
private final Usage usage;
protected OllamaChatResponseMetadata(Usage usage) {
this.usage = usage;
}
@Override
public Usage getUsage() {
return this.usage;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getUsage(), getRateLimit());
}
}

View File

@@ -1,5 +1,5 @@
package org.springframework.ai.openai;
public class ImageResponseMetadata {
public interface ImageResponseMetadata {
}

View File

@@ -15,14 +15,6 @@
*/
package org.springframework.ai.openai;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -30,6 +22,7 @@ import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
@@ -49,7 +42,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCom
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.OpenAiUsage;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
@@ -57,10 +50,17 @@ import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
* backed by {@link OpenAiApi}.
@@ -165,7 +165,7 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
}
// Non function calling.
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
List<Choice> choices = chatCompletion.choices();
if (choices == null) {
@@ -186,11 +186,22 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
}).toList();
return new ChatResponse(generations,
OpenAiChatResponseMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits));
return new ChatResponse(generations, from(completionEntity.getBody(), rateLimit));
});
}
public static ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.withId(result.id())
.withUsage(OpenAiUsage.from(result.usage()))
.withModel(result.model())
.withRateLimit(rateLimit)
.withKeyValue("created", result.created())
.withKeyValue("system-fingerprint", result.systemFingerprint())
.build();
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
@@ -237,7 +248,7 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
}).toList();
if (chatCompletion2.usage() != null) {
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion2));
return new ChatResponse(generations, from(chatCompletion2));
}
else {
return new ChatResponse(generations);
@@ -253,6 +264,17 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
});
}
private ChatResponseMetadata from(OpenAiApi.ChatCompletion result) {
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
return ChatResponseMetadata.builder()
.withId(result.id())
.withUsage(OpenAiUsage.from(result.usage()))
.withModel(result.model())
.withKeyValue("created", result.created())
.withKeyValue("system-fingerprint", result.systemFingerprint())
.build();
}
private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletion chatCompletion) {
ChatCompletionMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion);

View File

@@ -15,8 +15,6 @@
*/
package org.springframework.ai.openai;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
@@ -29,12 +27,13 @@ import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import java.util.List;
/**
* OpenAiImageModel is a class that implements the ImageModel interface. It provides a
* client for calling the OpenAI image generation API.
@@ -130,7 +129,7 @@ public class OpenAiImageModel implements ImageModel {
new OpenAiImageGenerationMetadata(entry.revisedPrompt()));
}).toList();
ImageResponseMetadata openAiImageResponseMetadata = OpenAiImageResponseMetadata.from(imageApiResponse);
ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
}

View File

@@ -1,103 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.metadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal OpenAI}.
*
* @author John Blum
* @author Thomas Vitale
* @see ChatResponseMetadata
* @see RateLimit
* @see Usage
* @since 0.7.0
*/
public class OpenAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";
public static OpenAiChatResponseMetadata from(OpenAiApi.ChatCompletion result) {
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
OpenAiUsage usage = OpenAiUsage.from(result.usage());
return new OpenAiChatResponseMetadata(result.id(), result.model(), usage);
}
private final String id;
private final String model;
@Nullable
private RateLimit rateLimit;
private final Usage usage;
protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage) {
this(id, model, usage, null);
}
protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage,
@Nullable OpenAiRateLimit rateLimit) {
this.id = id;
this.model = model;
this.usage = usage;
this.rateLimit = rateLimit;
}
@Override
public String getId() {
return this.id;
}
@Override
public String getModel() {
return this.model;
}
@Override
@Nullable
public RateLimit getRateLimit() {
RateLimit rateLimit = this.rateLimit;
return rateLimit != null ? rateLimit : new EmptyRateLimit();
}
@Override
public Usage getUsage() {
Usage usage = this.usage;
return usage != null ? usage : new EmptyUsage();
}
public OpenAiChatResponseMetadata withRateLimit(RateLimit rateLimit) {
this.rateLimit = rateLimit;
return this;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
}
}

View File

@@ -1,62 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.metadata;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.util.Assert;
import java.util.HashMap;
import java.util.Objects;
public class OpenAiImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private final Long created;
public static OpenAiImageResponseMetadata from(OpenAiImageApi.OpenAiImageResponse openAiImageResponse) {
Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null");
return new OpenAiImageResponseMetadata(openAiImageResponse.created());
}
protected OpenAiImageResponseMetadata(Long created) {
this.created = created;
}
@Override
public Long getCreated() {
return this.created;
}
@Override
public String toString() {
return "OpenAiImageResponseMetadata{" + "created=" + created + '}';
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof OpenAiImageResponseMetadata that))
return false;
return Objects.equals(created, that.created);
}
@Override
public int hashCode() {
return Objects.hash(created);
}
}

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.openai.metadata.audio;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.MutableResponseMetadata;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.lang.Nullable;
@@ -31,7 +32,7 @@ import java.util.HashMap;
* @author Ahmed Yousri
* @see RateLimit
*/
public class OpenAiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";

View File

@@ -17,6 +17,7 @@ package org.springframework.ai.openai.metadata.audio;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.MutableResponseMetadata;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.metadata.OpenAiRateLimit;
@@ -32,7 +33,7 @@ import java.util.HashMap;
* @since 0.8.1
* @see RateLimit
*/
public class OpenAiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
public class OpenAiAudioTranscriptionResponseMetadata extends MutableResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";

View File

@@ -22,6 +22,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
@@ -200,11 +201,11 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements
}
}
var metadata = new EmbeddingResponseMetadata(
Map.of("transformer", optionsToUse.getTransformer(), "vector-type", optionsToUse.getVectorType().name(),
"kwargs", ModelOptionsUtils.toJsonString(optionsToUse.getKwargs())));
return new EmbeddingResponse(data, metadata);
Map<String, Object> embeddingMetadata = Map.of("transformer", optionsToUse.getTransformer(), "vector-type",
optionsToUse.getVectorType().name(), "kwargs",
ModelOptionsUtils.toJsonString(optionsToUse.getKwargs()));
var embeddingResponseMetadata = new EmbeddingResponseMetadata("unknown", new EmptyUsage(), embeddingMetadata);
return new EmbeddingResponse(data, embeddingResponseMetadata);
}
/**

View File

@@ -30,6 +30,7 @@ import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType;
import org.testcontainers.containers.PostgreSQLContainer;
@@ -144,8 +145,21 @@ class PostgresMlEmbeddingModelIT {
assertThat(embeddingResponse).isNotNull();
assertThat(embeddingResponse.getResults()).hasSize(3);
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(
Map.of("transformer", "distilbert-base-uncased", "vector-type", vectorType, "kwargs", "{}"));
EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata();
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
assertThat(metadata.get("transformer").toString())
.as("Transformer in metadata should be 'distilbert-base-uncased'")
.isEqualTo("distilbert-base-uncased");
assertThat(metadata.get("vector-type").toString())
.as("Vector type in metadata should match expected vector type")
.isEqualTo(vectorType);
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}");
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
@@ -170,8 +184,22 @@ class PostgresMlEmbeddingModelIT {
assertThat(embeddingResponse).isNotNull();
assertThat(embeddingResponse.getResults()).hasSize(3);
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer",
"distilbert-base-uncased", "vector-type", VectorType.PG_VECTOR.name(), "kwargs", "{}"));
EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata();
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
assertThat(metadata.get("transformer").toString())
.as("Transformer in metadata should be 'distilbert-base-uncased'")
.isEqualTo("distilbert-base-uncased");
assertThat(metadata.get("vector-type").toString())
.as("Vector type in metadata should match expected vector type")
.isEqualTo(VectorType.PG_VECTOR.name());
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}");
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
@@ -192,8 +220,20 @@ class PostgresMlEmbeddingModelIT {
assertThat(embeddingResponse).isNotNull();
assertThat(embeddingResponse.getResults()).hasSize(3);
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer",
"intfloat/e5-small", "vector-type", VectorType.PG_ARRAY.name(), "kwargs", "{\"device\":\"cpu\"}"));
metadata = embeddingResponse.getMetadata();
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
assertThat(metadata.get("transformer").toString()).as("Transformer in metadata should be 'intfloat/e5-small'")
.isEqualTo("intfloat/e5-small");
assertThat(metadata.get("vector-type").toString()).as("Vector type in metadata should be PG_ARRAY")
.isEqualTo(VectorType.PG_ARRAY.name());
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{\"device\":\"cpu\"}'")
.isEqualTo("{\"device\":\"cpu\"}");
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(384);

View File

@@ -119,7 +119,7 @@ public class StabilityAiImageModel implements ImageModel {
new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed()));
}).toList();
return new ImageResponse(imageGenerationList, ImageResponseMetadata.NULL);
return new ImageResponse(imageGenerationList, new ImageResponseMetadata());
}
private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) {

View File

@@ -24,6 +24,7 @@ import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingResponse;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Christian Tzolov
@@ -76,7 +77,7 @@ public class TransformersEmbeddingModelTests {
embeddingModel.afterPropertiesSet();
EmbeddingResponse embed = embeddingModel.embedForResponse(List.of("Hello world", "World is big"));
assertThat(embed.getResults()).hasSize(2);
assertThat(embed.getMetadata()).isEmpty();
assertTrue(embed.getMetadata().isEmpty(), "Expected embed metadata to be empty, but it was not.");
assertThat(embed.getResults().get(0).getOutput()).hasSize(384);
assertThat(DF.format(embed.getResults().get(0).getOutput().get(0))).isEqualTo(DF.format(-0.19744634628295898));

View File

@@ -0,0 +1,28 @@
package org.springframework.ai.vertexai.embedding;
import org.springframework.ai.chat.metadata.Usage;
public class VertexAiEmbeddingUsage implements Usage {
private final Integer totalTokens;
public VertexAiEmbeddingUsage(Integer totalTokens) {
this.totalTokens = totalTokens;
}
@Override
public Long getPromptTokens() {
return 0L;
}
@Override
public Long getGenerationTokens() {
return 0L;
}
@Override
public Long getTotalTokens() {
return Long.valueOf(this.totalTokens);
}
}

View File

@@ -24,6 +24,7 @@ import com.google.protobuf.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.DocumentEmbeddingModel;
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
@@ -35,6 +36,7 @@ import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.MultimodalInstanceBuilder;
@@ -230,20 +232,18 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
String deploymentModelId = embeddingResponse.getDeployedModelId();
EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), -1);
responseMetadata.put("deployment-model-id",
Map<String, Object> metadataToUse = Map.of("deployment-model-id",
StringUtils.hasText(deploymentModelId) ? deploymentModelId : "unknown");
return new EmbeddingResponse(embeddingList, generateResponseMetadata(mergedOptions.getModel(), 0));
EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), 0,
metadataToUse);
return new EmbeddingResponse(embeddingList, responseMetadata);
}
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.put("model", model);
metadata.put("total-tokens", tokenCount);
return metadata;
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens,
Map<String, Object> metadataToUse) {
Usage usage = new VertexAiEmbeddingUsage(totalTokens);
return new EmbeddingResponseMetadata(model, usage, metadataToUse);
}
@Override

View File

@@ -20,6 +20,7 @@ import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
@@ -32,6 +33,7 @@ import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetai
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -135,10 +137,11 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
}
}
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) {
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.put("model", model);
metadata.put("total-tokens", tokenCount);
metadata.setModel(model);
Usage usage = new VertexAiEmbeddingUsage(totalTokens);
metadata.setUsage(usage);
return metadata;
}

View File

@@ -33,10 +33,10 @@ import org.springframework.core.io.ClassPathResource;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
@SpringBootTest(classes = VertexAiMultimodelEmbeddingModelIT.Config.class)
@SpringBootTest(classes = VertexAiMultimodalEmbeddingModelIT.Config.class)
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
class VertexAiMultimodelEmbeddingModelIT {
class VertexAiMultimodalEmbeddingModelIT {
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api
@@ -68,8 +68,13 @@ class VertexAiMultimodelEmbeddingModelIT {
.isEqualTo(embeddingRequest.getInstructions().get(1).getId());
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel())
.as("Model in metadata should be 'multimodalembedding@001'")
.isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens())
.as("Total tokens in metadata should be 0")
.isEqualTo("0");
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}
@@ -90,8 +95,8 @@ class VertexAiMultimodelEmbeddingModelIT {
.isEqualTo(MimeTypeUtils.TEXT_PLAIN);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}
@@ -113,8 +118,8 @@ class VertexAiMultimodelEmbeddingModelIT {
.isEqualTo(MimeTypeUtils.TEXT_PLAIN);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}
@@ -139,8 +144,8 @@ class VertexAiMultimodelEmbeddingModelIT {
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}
@@ -164,8 +169,8 @@ class VertexAiMultimodelEmbeddingModelIT {
.isEqualTo(new MimeType("video", "mp4"));
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}
@@ -198,8 +203,8 @@ class VertexAiMultimodelEmbeddingModelIT {
.isEqualTo(EmbeddingResultMetadata.ModalityType.VIDEO);
assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
}

View File

@@ -53,8 +53,12 @@ class VertexAiTextEmbeddingModelIT {
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", modelName);
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 5);
assertThat(embeddingResponse.getMetadata().getModel()).as("Model name in metadata should match expected model")
.isEqualTo(modelName);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens())
.as("Total tokens in metadata should be 5")
.isEqualTo(5L);
assertThat(embeddingModel.dimensions()).isEqualTo(768);
}

View File

@@ -15,37 +15,6 @@
*/
package org.springframework.ai.vertexai.gemini;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.google.cloud.vertexai.VertexAI;
@@ -63,10 +32,39 @@ import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.cloud.vertexai.generativeai.ResponseStream;
import com.google.protobuf.Struct;
import com.google.protobuf.util.JsonFormat;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ChatModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractToolCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* @author Christian Tzolov
* @author Grogdunn
@@ -244,8 +242,8 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport<GenerateCon
}
}
private VertexAiChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
return new VertexAiChatResponseMetadata(new VertexAiUsage(response.getUsageMetadata()));
private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build();
}
@JsonInclude(Include.NON_NULL)

View File

@@ -1,40 +0,0 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vertexai.gemini.metadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import java.util.HashMap;
/**
* @author Christian Tzolov
* @since 0.8.1
*/
public class VertexAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
private final VertexAiUsage usage;
public VertexAiChatResponseMetadata(VertexAiUsage usage) {
this.usage = usage;
}
@Override
public Usage getUsage() {
return this.usage;
}
}

View File

@@ -23,6 +23,7 @@ import java.util.stream.Collectors;
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.model.Content;
@@ -127,15 +128,17 @@ public class QuestionAnswerAdvisor implements RequestResponseAdvisor {
@Override
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
response.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return response;
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return chatResponseBuilder.build();
}
@Override
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
return fluxResponse.map(cr -> {
cr.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return cr;
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
return chatResponseBuilder.build();
});
}

View File

@@ -15,40 +15,47 @@
*/
package org.springframework.ai.chat.metadata;
import org.springframework.ai.model.AbstractResponseMetadata;
import org.springframework.ai.model.ResponseMetadata;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
* Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI
* response.
* Models common AI provider metadata returned in an AI response.
*
* @author John Blum
* @author Thomas Vitale
* @since 0.7.0
* @author Mark Pollack
* @since 1.0.0
*/
public interface ChatResponseMetadata extends ResponseMetadata {
public class ChatResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata {
class DefaultChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
private String id = ""; // Set to blank to preserve backward compat with previous
// interface default methods
}
private String model = "";
ChatResponseMetadata NULL = new DefaultChatResponseMetadata();
private RateLimit rateLimit = new EmptyRateLimit();
private Usage usage = new EmptyUsage();
private PromptMetadata promptMetadata = PromptMetadata.empty();
/**
* A unique identifier for the chat completion operation.
* @return unique operation identifier.
*/
default String getId() {
return "";
public String getId() {
return this.id;
}
/**
* The model that handled the request.
* @return the model that handled the request.
*/
default String getModel() {
return "";
public String getModel() {
return this.model;
}
/**
@@ -56,8 +63,8 @@ public interface ChatResponseMetadata extends ResponseMetadata {
* @return AI provider specific metadata on rate limits.
* @see RateLimit
*/
default RateLimit getRateLimit() {
return new EmptyRateLimit();
public RateLimit getRateLimit() {
return this.rateLimit;
}
/**
@@ -65,12 +72,90 @@ public interface ChatResponseMetadata extends ResponseMetadata {
* @return AI provider specific metadata on API usage.
* @see Usage
*/
default Usage getUsage() {
return new EmptyUsage();
public Usage getUsage() {
return this.usage;
}
default PromptMetadata getPromptMetadata() {
return PromptMetadata.empty();
/**
* Returns the prompt metadata gathered by the AI during request processing.
* @return the prompt metadata.
*/
public PromptMetadata getPromptMetadata() {
return this.promptMetadata;
}
public static class Builder {
private final ChatResponseMetadata chatResponseMetadata;
public Builder() {
this.chatResponseMetadata = new ChatResponseMetadata();
}
public Builder withMetadata(Map<String, Object> mapToCopy) {
this.chatResponseMetadata.map.putAll(mapToCopy);
return this;
}
public Builder withKeyValue(String key, Object value) {
this.chatResponseMetadata.map.put(key, value);
return this;
}
public Builder withId(String id) {
this.chatResponseMetadata.id = id;
return this;
}
public Builder withModel(String model) {
this.chatResponseMetadata.model = model;
return this;
}
public Builder withRateLimit(RateLimit rateLimit) {
this.chatResponseMetadata.rateLimit = rateLimit;
return this;
}
public Builder withUsage(Usage usage) {
this.chatResponseMetadata.usage = usage;
return this;
}
public Builder withPromptMetadata(PromptMetadata promptMetadata) {
this.chatResponseMetadata.promptMetadata = promptMetadata;
return this;
}
public ChatResponseMetadata build() {
return this.chatResponseMetadata;
}
}
public static Builder builder() {
return new Builder();
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof ChatResponseMetadata that))
return false;
return Objects.equals(id, that.id) && Objects.equals(model, that.model)
&& Objects.equals(rateLimit, that.rateLimit) && Objects.equals(usage, that.usage)
&& Objects.equals(promptMetadata, that.promptMetadata);
}
@Override
public int hashCode() {
return Objects.hash(id, model, rateLimit, usage, promptMetadata);
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit());
}
}

View File

@@ -16,7 +16,9 @@
package org.springframework.ai.chat.model;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.springframework.ai.model.ModelResponse;
import org.springframework.util.CollectionUtils;
@@ -40,7 +42,7 @@ public class ChatResponse implements ModelResponse<Generation> {
* provider.
*/
public ChatResponse(List<Generation> generations) {
this(generations, ChatResponseMetadata.NULL);
this(generations, new ChatResponseMetadata());
}
/**
@@ -107,4 +109,44 @@ public class ChatResponse implements ModelResponse<Generation> {
return Objects.hash(chatResponseMetadata, generations);
}
public static ChatResponse.Builder builder() {
return new ChatResponse.Builder();
}
public static class Builder {
private List<Generation> generations;
private ChatResponseMetadata.Builder chatResponseMetadataBuilder;
private Builder() {
this.chatResponseMetadataBuilder = ChatResponseMetadata.builder();
}
public Builder from(ChatResponse other) {
this.generations = other.generations;
Set<Map.Entry<String, Object>> entries = other.chatResponseMetadata.entrySet();
for (Map.Entry<String, Object> entry : entries) {
this.chatResponseMetadataBuilder.withKeyValue(entry.getKey(), entry.getValue());
}
return this;
}
public Builder withMetadata(String key, Object value) {
this.chatResponseMetadataBuilder.withKeyValue(key, value);
return this;
}
public Builder withGenerations(List<Generation> generations) {
this.generations = generations;
return this;
}
public ChatResponse build() {
return new ChatResponse(generations, chatResponseMetadataBuilder.build());
}
}
}

View File

@@ -15,24 +15,20 @@
*/
package org.springframework.ai.embedding;
import java.io.Serial;
import java.util.HashMap;
import java.util.Map;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.model.AbstractResponseMetadata;
import org.springframework.ai.model.ResponseMetadata;
import java.util.Map;
/**
* Common AI provider metadata returned in an embedding response.
*
* @author Christian Tzolov
* @author Thomas Vitale
*/
public class EmbeddingResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
@Serial
private static final long serialVersionUID = 1L;
public class EmbeddingResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata {
private String model;
@@ -42,12 +38,15 @@ public class EmbeddingResponseMetadata extends HashMap<String, Object> implement
}
public EmbeddingResponseMetadata(String model, Usage usage) {
this.model = model;
this.usage = usage;
this(model, usage, Map.of());
}
public EmbeddingResponseMetadata(Map<String, ?> metadata) {
super(metadata);
public EmbeddingResponseMetadata(String model, Usage usage, Map<String, Object> metadata) {
this.model = model;
this.usage = usage;
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
this.map.put(entry.getKey(), entry.getValue());
}
}
/**

View File

@@ -43,7 +43,7 @@ public class ImageResponse implements ModelResponse<ImageGeneration> {
* provider.
*/
public ImageResponse(List<ImageGeneration> generations) {
this(generations, ImageResponseMetadata.NULL);
this(generations, new ImageResponseMetadata());
}
/**

View File

@@ -15,6 +15,7 @@
*/
package org.springframework.ai.image;
import org.springframework.ai.model.MutableResponseMetadata;
import org.springframework.ai.model.ResponseMetadata;
import java.util.HashMap;
@@ -28,16 +29,20 @@ import java.util.HashMap;
* @author Thomas Vitale
* @since 1.0.0
*/
public interface ImageResponseMetadata extends ResponseMetadata {
public class ImageResponseMetadata extends MutableResponseMetadata {
class DefaultImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private Long created;
public ImageResponseMetadata() {
this.created = System.currentTimeMillis();
}
ImageResponseMetadata NULL = new DefaultImageResponseMetadata();
public ImageResponseMetadata(Long created) {
this.created = created;
}
default Long getCreated() {
return System.currentTimeMillis();
public Long getCreated() {
return this.created;
}
}

View File

@@ -0,0 +1,76 @@
package org.springframework.ai.model;
import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
public class AbstractResponseMetadata {
protected static final String AI_METADATA_STRING = "{ id: %2$s, usage: %3$s, rateLimit: %4$s }";
protected final Map<String, Object> map = new ConcurrentHashMap<>();
/**
* Gets an entry from the context. Returns {@code null} when entry is not present.
* @param key key
* @param <T> value type
* @return entry or {@code null} if not present
*/
@Nullable
public <T> T get(String key) {
return (T) this.map.get(key);
}
/**
* Gets an entry from the context. Throws exception when entry is not present.
* @param key key
* @param <T> value type
* @return entry
* @throws IllegalArgumentException if not present
*/
@NonNull
public <T> T getRequired(Object key) {
T object = (T) this.map.get(key);
if (object == null) {
throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]");
}
return object;
}
/**
* Checks if context contains a key.
* @param key key
* @return {@code true} when the context contains the entry with the given key
*/
public boolean containsKey(Object key) {
return this.map.containsKey(key);
}
/**
* Returns an element or default if not present.
* @param key key
* @param defaultObject default object to return
* @param <T> value type
* @return object or default if not present
*/
public <T> T getOrDefault(Object key, T defaultObject) {
return (T) this.map.getOrDefault(key, defaultObject);
}
public Set<Map.Entry<String, Object>> entrySet() {
return Collections.unmodifiableMap(this.map).entrySet();
}
public Set<String> keySet() {
return Collections.unmodifiableSet(this.map.keySet());
}
public boolean isEmpty() {
return this.map.isEmpty();
}
}

View File

@@ -0,0 +1,126 @@
package org.springframework.ai.model;
import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
public class MutableResponseMetadata implements ResponseMetadata {
private final Map<String, Object> map = new ConcurrentHashMap<>();
/**
* Puts an element to the context.
* @param key key
* @param object value
* @param <T> value type
* @return this for chaining
*/
public <T> MutableResponseMetadata put(String key, T object) {
this.map.put(key, object);
return this;
}
/**
* Gets an entry from the context. Returns {@code null} when entry is not present.
* @param key key
* @param <T> value type
* @return entry or {@code null} if not present
*/
@Override
@Nullable
public <T> T get(String key) {
return (T) this.map.get(key);
}
/**
* Removes an entry from the context.
* @param key key by which to remove an entry
* @return the previous value associated with the key, or null if there was no mapping
* for the key
*/
public Object remove(Object key) {
return this.map.remove(key);
}
/**
* Gets an entry from the context. Throws exception when entry is not present.
* @param key key
* @param <T> value type
* @throws IllegalArgumentException if not present
* @return entry
*/
@Override
@NonNull
public <T> T getRequired(Object key) {
T object = (T) this.map.get(key);
if (object == null) {
throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]");
}
return object;
}
/**
* Checks if context contains a key.
* @param key key
* @return {@code true} when the context contains the entry with the given key
*/
@Override
public boolean containsKey(Object key) {
return this.map.containsKey(key);
}
/**
* Returns an element or default if not present.
* @param key key
* @param defaultObject default object to return
* @param <T> value type
* @return object or default if not present
*/
@Override
public <T> T getOrDefault(Object key, T defaultObject) {
return (T) this.map.getOrDefault(key, defaultObject);
}
@Override
public Set<Map.Entry<String, Object>> entrySet() {
return Collections.unmodifiableMap(this.map).entrySet();
}
public Set<String> keySet() {
return Collections.unmodifiableSet(this.map.keySet());
}
@Override
public boolean isEmpty() {
return this.map.isEmpty();
}
/**
* Returns an element or calls a mapping function if entry not present. The function
* will insert the value to the map.
* @param key key
* @param mappingFunction mapping function
* @param <T> value type
* @return object or one derived from the mapping function if not present
*/
public <T> T computeIfAbsent(String key, Function<Object, ? extends T> mappingFunction) {
return (T) this.map.computeIfAbsent(key, mappingFunction);
}
/**
* Clears the entries from the context.
*/
public void clear() {
this.map.clear();
}
public Map<String, Object> getRawMap() {
return map;
}
}

View File

@@ -15,18 +15,77 @@
*/
package org.springframework.ai.model;
import io.micrometer.common.lang.NonNull;
import io.micrometer.common.lang.Nullable;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
/**
* Interface representing metadata associated with an AI model's response. This interface
* is designed to provide additional information about the generative response from an AI
* model, including processing details and model-specific data. It serves as a value
* object within the core domain, enhancing the understanding and management of AI model
* responses in various applications.
* Interface representing metadata associated with an AI model's response.
*
* @author Mark Pollack
* @since 0.8.0
* @since 1.0.0
*/
public interface ResponseMetadata extends Map<String, Object> {
public interface ResponseMetadata {
/**
* Gets an entry from the context. Returns {@code null} when entry is not present.
* @param key key
* @param <T> value type
* @return entry or {@code null} if not present
*/
@Nullable
<T> T get(String key);
/**
* Gets an entry from the context. Throws exception when entry is not present.
* @param key key
* @param <T> value type
* @throws IllegalArgumentException if not present
* @return entry
*/
@NonNull
<T> T getRequired(Object key);
/**
* Checks if context contains a key.
* @param key key
* @return {@code true} when the context contains the entry with the given key
*/
boolean containsKey(Object key);
/**
* Returns an element or default if not present.
* @param key key
* @param defaultObject default object to return
* @param <T> value type
* @return object or default if not present
*/
<T> T getOrDefault(Object key, T defaultObject);
/**
* Returns an element or default if not present.
* @param key key
* @param defaultObjectSupplier supplier for default object to return
* @param <T> value type
* @return object or default if not present
* @since 1.11.0
*/
default <T> T getOrDefault(String key, Supplier<T> defaultObjectSupplier) {
T value = get(key);
return value != null ? value : defaultObjectSupplier.get();
}
Set<Map.Entry<String, Object>> entrySet();
public Set<String> keySet();
/**
* Returns {@code true} if this map contains no key-value mappings.
* @return {@code true} if this map contains no key-value mappings
*/
boolean isEmpty();
}

View File

@@ -136,9 +136,7 @@ public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
// The chat completion tool call requires the complete conversation
// history. Including the initial user message.
List<Msg> conversationHistory = new ArrayList<>();
conversationHistory.addAll(this.doGetUserMessages(request));
List<Msg> conversationHistory = new ArrayList<>(this.doGetUserMessages(request));
Msg responseMessage = this.doGetToolResponseMessage(response);
@@ -164,9 +162,7 @@ public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
// The chat completion tool call requires the complete conversation
// history. Including the initial user message.
List<Msg> conversationHistory = new ArrayList<>();
conversationHistory.addAll(this.doGetUserMessages(request));
List<Msg> conversationHistory = new ArrayList<>(this.doGetUserMessages(request));
Msg responseMessage = this.doGetToolResponseMessage(resp);

View File

@@ -29,7 +29,6 @@ import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -58,8 +57,7 @@ public class ChatClientResponseEntityTests {
@Test
public void responseEntityTest() {
ChatResponseMetadata metadata = new DefaultChatResponseMetadata();
metadata.put("key1", "value1");
ChatResponseMetadata metadata = ChatResponseMetadata.builder().withKeyValue("key1", "value1").build();
var chatResponse = new ChatResponse(List.of(new Generation("""
{"name":"John", "age":30}
@@ -75,7 +73,7 @@ public class ChatClientResponseEntityTests {
.responseEntity(MyBean.class);
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1");
assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1");
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));

View File

@@ -112,8 +112,8 @@ public class VertexAiTextEmbeddingModelAutoConfigurationIT {
.isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT);
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
assertThat(embeddingResponse.getMetadata().getUsage()).isEqualTo(0);
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);