Update to ResponseMetadata design
* Remove inheritance from HashMap * No more subclasses per model provider * Builder class for ChatResponse
This commit is contained in:
@@ -15,14 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.anthropic;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi;
|
||||
@@ -32,12 +24,13 @@ import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionResponse;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
|
||||
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
|
||||
import org.springframework.ai.anthropic.metadata.AnthropicUsage;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -52,10 +45,17 @@ import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* The {@link ChatModel} implementation for the Anthropic service.
|
||||
*
|
||||
@@ -228,7 +228,20 @@ public class AnthropicChatModel extends AbstractToolCallSupport<ChatCompletionRe
|
||||
.withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
|
||||
}).toList();
|
||||
|
||||
return new ChatResponse(generations, AnthropicChatResponseMetadata.from(chatCompletion));
|
||||
return new ChatResponse(generations, from(chatCompletion));
|
||||
}
|
||||
|
||||
private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
|
||||
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
|
||||
AnthropicUsage usage = AnthropicUsage.from(result.usage());
|
||||
return ChatResponseMetadata.builder()
|
||||
.withId(result.id())
|
||||
.withModel(result.model())
|
||||
.withUsage(usage)
|
||||
.withKeyValue("stop-reason", result.stopReason())
|
||||
.withKeyValue("stop-sequence", result.stopSequence())
|
||||
.withKeyValue("type", result.type())
|
||||
.build();
|
||||
}
|
||||
|
||||
private String fromMediaData(Object mediaData) {
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.anthropic.metadata;
|
||||
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for {@literal AnthropicApi}.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @see ChatResponseMetadata
|
||||
* @see RateLimit
|
||||
* @see Usage
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class AnthropicChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";
|
||||
|
||||
public static AnthropicChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
|
||||
Assert.notNull(result, "Anthropic ChatCompletionResult must not be null");
|
||||
AnthropicUsage usage = AnthropicUsage.from(result.usage());
|
||||
return new AnthropicChatResponseMetadata(result.id(), result.model(), usage);
|
||||
}
|
||||
|
||||
private final String id;
|
||||
|
||||
private final String model;
|
||||
|
||||
@Nullable
|
||||
private RateLimit rateLimit;
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage) {
|
||||
this(id, model, usage, null);
|
||||
}
|
||||
|
||||
protected AnthropicChatResponseMetadata(String id, String model, AnthropicUsage usage,
|
||||
@Nullable AnthropicRateLimit rateLimit) {
|
||||
this.id = id;
|
||||
this.model = model;
|
||||
this.usage = usage;
|
||||
this.rateLimit = rateLimit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
public RateLimit getRateLimit() {
|
||||
RateLimit rl = this.rateLimit;
|
||||
return rl != null ? rl : new EmptyRateLimit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
Usage usage = this.usage;
|
||||
return usage != null ? usage : new EmptyUsage();
|
||||
}
|
||||
|
||||
public AnthropicChatResponseMetadata withRateLimit(RateLimit rateLimit) {
|
||||
this.rateLimit = rateLimit;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15,33 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.azure.openai;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractToolCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import com.azure.ai.openai.OpenAIClient;
|
||||
import com.azure.ai.openai.models.ChatChoice;
|
||||
import com.azure.ai.openai.models.ChatCompletions;
|
||||
@@ -68,10 +41,36 @@ import com.azure.ai.openai.models.FunctionCall;
|
||||
import com.azure.ai.openai.models.FunctionDefinition;
|
||||
import com.azure.core.util.BinaryData;
|
||||
import com.azure.core.util.IterableStream;
|
||||
|
||||
import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractToolCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
|
||||
* {@link OpenAIClient}.
|
||||
@@ -151,8 +150,22 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport<ChatCompletion
|
||||
|
||||
PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);
|
||||
|
||||
return new ChatResponse(generations,
|
||||
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
|
||||
return new ChatResponse(generations, from(chatCompletions, promptFilterMetadata));
|
||||
}
|
||||
|
||||
public static ChatResponseMetadata from(ChatCompletions chatCompletions, PromptMetadata promptFilterMetadata) {
|
||||
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
|
||||
String id = chatCompletions.getId();
|
||||
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
|
||||
ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata.builder()
|
||||
.withId(id)
|
||||
.withUsage(usage)
|
||||
.withModel(chatCompletions.getModel())
|
||||
.withPromptMetadata(promptFilterMetadata)
|
||||
.withKeyValue("system-fingerprint", chatCompletions.getSystemFingerprint())
|
||||
.build();
|
||||
|
||||
return chatResponseMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.azure.openai.metadata;
|
||||
|
||||
import com.azure.ai.openai.models.ChatCompletions;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for
|
||||
* {@literal Microsoft Azure OpenAI Service}.
|
||||
*
|
||||
* @author John Blum
|
||||
* @author Thomas Vitale
|
||||
* @see ChatResponseMetadata
|
||||
* @since 0.7.1
|
||||
*/
|
||||
public class AzureOpenAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
|
||||
|
||||
@SuppressWarnings("all")
|
||||
public static AzureOpenAiChatResponseMetadata from(ChatCompletions chatCompletions,
|
||||
PromptMetadata promptFilterMetadata) {
|
||||
Assert.notNull(chatCompletions, "Azure OpenAI ChatCompletions must not be null");
|
||||
String id = chatCompletions.getId();
|
||||
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
|
||||
AzureOpenAiChatResponseMetadata chatResponseMetadata = new AzureOpenAiChatResponseMetadata(id, usage,
|
||||
promptFilterMetadata);
|
||||
return chatResponseMetadata;
|
||||
}
|
||||
|
||||
private final String id;
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
private final PromptMetadata promptMetadata;
|
||||
|
||||
protected AzureOpenAiChatResponseMetadata(String id, AzureOpenAiUsage usage, PromptMetadata promptMetadata) {
|
||||
this.id = id;
|
||||
this.usage = usage;
|
||||
this.promptMetadata = promptMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptMetadata getPromptMetadata() {
|
||||
return this.promptMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package org.springframework.ai.azure.openai.metadata;
|
||||
|
||||
import com.azure.ai.openai.models.ImageGenerations;
|
||||
import org.springframework.ai.image.ImageResponseMetadata;
|
||||
import org.springframework.ai.model.MutableResponseMetadata;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -15,7 +16,7 @@ import java.util.Objects;
|
||||
* @author Benoit Moussaud
|
||||
* @since 1.0.0 M1
|
||||
*/
|
||||
public class AzureOpenAiImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
public class AzureOpenAiImageResponseMetadata extends ImageResponseMetadata {
|
||||
|
||||
private final Long created;
|
||||
|
||||
|
||||
@@ -15,14 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.mistralai;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
@@ -31,6 +23,7 @@ import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -44,7 +37,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.mistralai.metadata.MistralAiChatResponseMetadata;
|
||||
import org.springframework.ai.mistralai.metadata.MistralAiUsage;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractToolCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
@@ -53,10 +46,17 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* @author Ricken Bazolo
|
||||
* @author Christian Tzolov
|
||||
@@ -134,10 +134,21 @@ public class MistralAiChatModel extends AbstractToolCallSupport<ChatCompletion>
|
||||
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
|
||||
.toList();
|
||||
|
||||
return new ChatResponse(generations, MistralAiChatResponseMetadata.from(chatCompletion));
|
||||
return new ChatResponse(generations, from(chatCompletion));
|
||||
});
|
||||
}
|
||||
|
||||
public static ChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
|
||||
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
|
||||
MistralAiUsage usage = MistralAiUsage.from(result.usage());
|
||||
return ChatResponseMetadata.builder()
|
||||
.withId(result.id())
|
||||
.withModel(result.model())
|
||||
.withUsage(usage)
|
||||
.withKeyValue("created", result.created())
|
||||
.build();
|
||||
}
|
||||
|
||||
private Map<String, Object> toMap(String id, ChatCompletion.Choice choice) {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package org.springframework.ai.mistralai.metadata;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for {@literal Mistral AI}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @see ChatResponseMetadata
|
||||
* @see Usage
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class MistralAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s }";
|
||||
|
||||
public static MistralAiChatResponseMetadata from(MistralAiApi.ChatCompletion result) {
|
||||
Assert.notNull(result, "Mistral AI ChatCompletion must not be null");
|
||||
MistralAiUsage usage = MistralAiUsage.from(result.usage());
|
||||
return new MistralAiChatResponseMetadata(result.id(), result.model(), usage);
|
||||
}
|
||||
|
||||
private final String id;
|
||||
|
||||
private final String model;
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
protected MistralAiChatResponseMetadata(String id, String model, MistralAiUsage usage) {
|
||||
this.id = id;
|
||||
this.model = model;
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
Usage usage = this.usage;
|
||||
return usage != null ? usage : new EmptyUsage();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15,27 +15,27 @@
|
||||
*/
|
||||
package org.springframework.ai.ollama;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.metadata.OllamaUsage;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link ChatModel} implementation for {@literal Ollama}.
|
||||
@@ -102,7 +102,23 @@ public class OllamaChatModel implements ChatModel {
|
||||
if (response.promptEvalCount() != null && response.evalCount() != null) {
|
||||
generator = generator.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
|
||||
}
|
||||
return new ChatResponse(List.of(generator), OllamaChatResponseMetadata.from(response));
|
||||
return new ChatResponse(List.of(generator), from(response));
|
||||
}
|
||||
|
||||
public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
|
||||
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
|
||||
return ChatResponseMetadata.builder()
|
||||
.withUsage(OllamaUsage.from(response))
|
||||
.withModel(response.model())
|
||||
.withKeyValue("created-at", response.createdAt())
|
||||
.withKeyValue("eval-duration", response.evalDuration())
|
||||
.withKeyValue("eval-count", response.evalCount())
|
||||
.withKeyValue("load-duration", response.loadDuration())
|
||||
.withKeyValue("eval-duration", response.promptEvalDuration())
|
||||
.withKeyValue("eval-count", response.promptEvalCount())
|
||||
.withKeyValue("total-duration", response.totalDuration())
|
||||
.withKeyValue("done", response.done())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -116,7 +132,7 @@ public class OllamaChatModel implements ChatModel {
|
||||
if (Boolean.TRUE.equals(chunk.done())) {
|
||||
generation = generation.withGenerationMetadata(ChatGenerationMetadata.from("unknown", null));
|
||||
}
|
||||
return new ChatResponse(List.of(generation), OllamaChatResponseMetadata.from(chunk));
|
||||
return new ChatResponse(List.of(generation), from(chunk));
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.ollama.metadata;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for {@literal Ollama}
|
||||
*
|
||||
* @see ChatResponseMetadata
|
||||
* @author Fu Cheng
|
||||
*/
|
||||
public class OllamaChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, usage: %2$s, rateLimit: %3$s }";
|
||||
|
||||
public static OllamaChatResponseMetadata from(OllamaApi.ChatResponse response) {
|
||||
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
|
||||
Usage usage = OllamaUsage.from(response);
|
||||
return new OllamaChatResponseMetadata(usage);
|
||||
}
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
protected OllamaChatResponseMetadata(Usage usage) {
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
package org.springframework.ai.openai;
|
||||
|
||||
public class ImageResponseMetadata {
|
||||
public interface ImageResponseMetadata {
|
||||
|
||||
}
|
||||
|
||||
@@ -15,14 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.openai;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
@@ -30,6 +22,7 @@ import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
@@ -49,7 +42,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ChatCom
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
|
||||
import org.springframework.ai.openai.metadata.OpenAiUsage;
|
||||
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -57,10 +50,17 @@ import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.MimeType;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
|
||||
* backed by {@link OpenAiApi}.
|
||||
@@ -165,7 +165,7 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
|
||||
}
|
||||
|
||||
// Non function calling.
|
||||
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
|
||||
RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity);
|
||||
|
||||
List<Choice> choices = chatCompletion.choices();
|
||||
if (choices == null) {
|
||||
@@ -186,11 +186,22 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
|
||||
|
||||
}).toList();
|
||||
|
||||
return new ChatResponse(generations,
|
||||
OpenAiChatResponseMetadata.from(completionEntity.getBody()).withRateLimit(rateLimits));
|
||||
return new ChatResponse(generations, from(completionEntity.getBody(), rateLimit));
|
||||
});
|
||||
}
|
||||
|
||||
public static ChatResponseMetadata from(OpenAiApi.ChatCompletion result, RateLimit rateLimit) {
|
||||
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
|
||||
return ChatResponseMetadata.builder()
|
||||
.withId(result.id())
|
||||
.withUsage(OpenAiUsage.from(result.usage()))
|
||||
.withModel(result.model())
|
||||
.withRateLimit(rateLimit)
|
||||
.withKeyValue("created", result.created())
|
||||
.withKeyValue("system-fingerprint", result.systemFingerprint())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
|
||||
@@ -237,7 +248,7 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
|
||||
}).toList();
|
||||
|
||||
if (chatCompletion2.usage() != null) {
|
||||
return new ChatResponse(generations, OpenAiChatResponseMetadata.from(chatCompletion2));
|
||||
return new ChatResponse(generations, from(chatCompletion2));
|
||||
}
|
||||
else {
|
||||
return new ChatResponse(generations);
|
||||
@@ -253,6 +264,17 @@ public class OpenAiChatModel extends AbstractToolCallSupport<ChatCompletion> imp
|
||||
});
|
||||
}
|
||||
|
||||
private ChatResponseMetadata from(OpenAiApi.ChatCompletion result) {
|
||||
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
|
||||
return ChatResponseMetadata.builder()
|
||||
.withId(result.id())
|
||||
.withUsage(OpenAiUsage.from(result.usage()))
|
||||
.withModel(result.model())
|
||||
.withKeyValue("created", result.created())
|
||||
.withKeyValue("system-fingerprint", result.systemFingerprint())
|
||||
.build();
|
||||
}
|
||||
|
||||
private List<Message> handleToolCallRequests(List<Message> previousMessages, ChatCompletion chatCompletion) {
|
||||
|
||||
ChatCompletionMessage nativeAssistantMessage = this.extractAssistantMessage(chatCompletion);
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.openai;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.image.Image;
|
||||
@@ -29,12 +27,13 @@ import org.springframework.ai.image.ImageResponseMetadata;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
|
||||
import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* OpenAiImageModel is a class that implements the ImageModel interface. It provides a
|
||||
* client for calling the OpenAI image generation API.
|
||||
@@ -130,7 +129,7 @@ public class OpenAiImageModel implements ImageModel {
|
||||
new OpenAiImageGenerationMetadata(entry.revisedPrompt()));
|
||||
}).toList();
|
||||
|
||||
ImageResponseMetadata openAiImageResponseMetadata = OpenAiImageResponseMetadata.from(imageApiResponse);
|
||||
ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
|
||||
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.openai.metadata;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatResponseMetadata} implementation for {@literal OpenAI}.
|
||||
*
|
||||
* @author John Blum
|
||||
* @author Thomas Vitale
|
||||
* @see ChatResponseMetadata
|
||||
* @see RateLimit
|
||||
* @see Usage
|
||||
* @since 0.7.0
|
||||
*/
|
||||
public class OpenAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, model: %3$s, usage: %4$s, rateLimit: %5$s }";
|
||||
|
||||
public static OpenAiChatResponseMetadata from(OpenAiApi.ChatCompletion result) {
|
||||
Assert.notNull(result, "OpenAI ChatCompletionResult must not be null");
|
||||
OpenAiUsage usage = OpenAiUsage.from(result.usage());
|
||||
return new OpenAiChatResponseMetadata(result.id(), result.model(), usage);
|
||||
}
|
||||
|
||||
private final String id;
|
||||
|
||||
private final String model;
|
||||
|
||||
@Nullable
|
||||
private RateLimit rateLimit;
|
||||
|
||||
private final Usage usage;
|
||||
|
||||
protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage) {
|
||||
this(id, model, usage, null);
|
||||
}
|
||||
|
||||
protected OpenAiChatResponseMetadata(String id, String model, OpenAiUsage usage,
|
||||
@Nullable OpenAiRateLimit rateLimit) {
|
||||
this.id = id;
|
||||
this.model = model;
|
||||
this.usage = usage;
|
||||
this.rateLimit = rateLimit;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
public RateLimit getRateLimit() {
|
||||
RateLimit rateLimit = this.rateLimit;
|
||||
return rateLimit != null ? rateLimit : new EmptyRateLimit();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
Usage usage = this.usage;
|
||||
return usage != null ? usage : new EmptyUsage();
|
||||
}
|
||||
|
||||
public OpenAiChatResponseMetadata withRateLimit(RateLimit rateLimit) {
|
||||
this.rateLimit = rateLimit;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getClass().getName(), getId(), getModel(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.openai.metadata;
|
||||
|
||||
import org.springframework.ai.image.ImageResponseMetadata;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
|
||||
public class OpenAiImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
|
||||
private final Long created;
|
||||
|
||||
public static OpenAiImageResponseMetadata from(OpenAiImageApi.OpenAiImageResponse openAiImageResponse) {
|
||||
Assert.notNull(openAiImageResponse, "OpenAiImageResponse must not be null");
|
||||
return new OpenAiImageResponseMetadata(openAiImageResponse.created());
|
||||
}
|
||||
|
||||
protected OpenAiImageResponseMetadata(Long created) {
|
||||
this.created = created;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getCreated() {
|
||||
return this.created;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OpenAiImageResponseMetadata{" + "created=" + created + '}';
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
return true;
|
||||
if (!(o instanceof OpenAiImageResponseMetadata that))
|
||||
return false;
|
||||
return Objects.equals(created, that.created);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(created);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.openai.metadata.audio;
|
||||
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.model.MutableResponseMetadata;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.lang.Nullable;
|
||||
@@ -31,7 +32,7 @@ import java.util.HashMap;
|
||||
* @author Ahmed Yousri
|
||||
* @see RateLimit
|
||||
*/
|
||||
public class OpenAiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
|
||||
public class OpenAiAudioSpeechResponseMetadata extends MutableResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ package org.springframework.ai.openai.metadata.audio;
|
||||
|
||||
import org.springframework.ai.chat.metadata.EmptyRateLimit;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.model.MutableResponseMetadata;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiRateLimit;
|
||||
@@ -32,7 +33,7 @@ import java.util.HashMap;
|
||||
* @since 0.8.1
|
||||
* @see RateLimit
|
||||
*/
|
||||
public class OpenAiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
|
||||
public class OpenAiAudioTranscriptionResponseMetadata extends MutableResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
@@ -200,11 +201,11 @@ public class PostgresMlEmbeddingModel extends AbstractEmbeddingModel implements
|
||||
}
|
||||
}
|
||||
|
||||
var metadata = new EmbeddingResponseMetadata(
|
||||
Map.of("transformer", optionsToUse.getTransformer(), "vector-type", optionsToUse.getVectorType().name(),
|
||||
"kwargs", ModelOptionsUtils.toJsonString(optionsToUse.getKwargs())));
|
||||
|
||||
return new EmbeddingResponse(data, metadata);
|
||||
Map<String, Object> embeddingMetadata = Map.of("transformer", optionsToUse.getTransformer(), "vector-type",
|
||||
optionsToUse.getVectorType().name(), "kwargs",
|
||||
ModelOptionsUtils.toJsonString(optionsToUse.getKwargs()));
|
||||
var embeddingResponseMetadata = new EmbeddingResponseMetadata("unknown", new EmptyUsage(), embeddingMetadata);
|
||||
return new EmbeddingResponse(data, embeddingResponseMetadata);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -30,6 +30,7 @@ import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
|
||||
import org.springframework.ai.postgresml.PostgresMlEmbeddingModel.VectorType;
|
||||
|
||||
import org.testcontainers.containers.PostgreSQLContainer;
|
||||
@@ -144,8 +145,21 @@ class PostgresMlEmbeddingModelIT {
|
||||
|
||||
assertThat(embeddingResponse).isNotNull();
|
||||
assertThat(embeddingResponse.getResults()).hasSize(3);
|
||||
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(
|
||||
Map.of("transformer", "distilbert-base-uncased", "vector-type", vectorType, "kwargs", "{}"));
|
||||
|
||||
EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata();
|
||||
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
|
||||
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
|
||||
|
||||
assertThat(metadata.get("transformer").toString())
|
||||
.as("Transformer in metadata should be 'distilbert-base-uncased'")
|
||||
.isEqualTo("distilbert-base-uncased");
|
||||
|
||||
assertThat(metadata.get("vector-type").toString())
|
||||
.as("Vector type in metadata should match expected vector type")
|
||||
.isEqualTo(vectorType);
|
||||
|
||||
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}");
|
||||
|
||||
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
|
||||
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
|
||||
@@ -170,8 +184,22 @@ class PostgresMlEmbeddingModelIT {
|
||||
|
||||
assertThat(embeddingResponse).isNotNull();
|
||||
assertThat(embeddingResponse.getResults()).hasSize(3);
|
||||
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer",
|
||||
"distilbert-base-uncased", "vector-type", VectorType.PG_VECTOR.name(), "kwargs", "{}"));
|
||||
|
||||
EmbeddingResponseMetadata metadata = embeddingResponse.getMetadata();
|
||||
|
||||
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
|
||||
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
|
||||
|
||||
assertThat(metadata.get("transformer").toString())
|
||||
.as("Transformer in metadata should be 'distilbert-base-uncased'")
|
||||
.isEqualTo("distilbert-base-uncased");
|
||||
|
||||
assertThat(metadata.get("vector-type").toString())
|
||||
.as("Vector type in metadata should match expected vector type")
|
||||
.isEqualTo(VectorType.PG_VECTOR.name());
|
||||
|
||||
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{}'").isEqualTo("{}");
|
||||
|
||||
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
|
||||
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
|
||||
@@ -192,8 +220,20 @@ class PostgresMlEmbeddingModelIT {
|
||||
|
||||
assertThat(embeddingResponse).isNotNull();
|
||||
assertThat(embeddingResponse.getResults()).hasSize(3);
|
||||
assertThat(embeddingResponse.getMetadata()).containsExactlyInAnyOrderEntriesOf(Map.of("transformer",
|
||||
"intfloat/e5-small", "vector-type", VectorType.PG_ARRAY.name(), "kwargs", "{\"device\":\"cpu\"}"));
|
||||
|
||||
metadata = embeddingResponse.getMetadata();
|
||||
|
||||
assertThat(metadata.keySet()).as("Metadata should contain exactly the expected keys")
|
||||
.containsExactlyInAnyOrder("transformer", "vector-type", "kwargs");
|
||||
|
||||
assertThat(metadata.get("transformer").toString()).as("Transformer in metadata should be 'intfloat/e5-small'")
|
||||
.isEqualTo("intfloat/e5-small");
|
||||
|
||||
assertThat(metadata.get("vector-type").toString()).as("Vector type in metadata should be PG_ARRAY")
|
||||
.isEqualTo(VectorType.PG_ARRAY.name());
|
||||
|
||||
assertThat(metadata.get("kwargs").toString()).as("kwargs in metadata should be '{\"device\":\"cpu\"}'")
|
||||
.isEqualTo("{\"device\":\"cpu\"}");
|
||||
|
||||
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(384);
|
||||
|
||||
@@ -119,7 +119,7 @@ public class StabilityAiImageModel implements ImageModel {
|
||||
new StabilityAiImageGenerationMetadata(entry.finishReason(), entry.seed()));
|
||||
}).toList();
|
||||
|
||||
return new ImageResponse(imageGenerationList, ImageResponseMetadata.NULL);
|
||||
return new ImageResponse(imageGenerationList, new ImageResponseMetadata());
|
||||
}
|
||||
|
||||
private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) {
|
||||
|
||||
@@ -24,6 +24,7 @@ import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
@@ -76,7 +77,7 @@ public class TransformersEmbeddingModelTests {
|
||||
embeddingModel.afterPropertiesSet();
|
||||
EmbeddingResponse embed = embeddingModel.embedForResponse(List.of("Hello world", "World is big"));
|
||||
assertThat(embed.getResults()).hasSize(2);
|
||||
assertThat(embed.getMetadata()).isEmpty();
|
||||
assertTrue(embed.getMetadata().isEmpty(), "Expected embed metadata to be empty, but it was not.");
|
||||
|
||||
assertThat(embed.getResults().get(0).getOutput()).hasSize(384);
|
||||
assertThat(DF.format(embed.getResults().get(0).getOutput().get(0))).isEqualTo(DF.format(-0.19744634628295898));
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package org.springframework.ai.vertexai.embedding;
|
||||
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
|
||||
public class VertexAiEmbeddingUsage implements Usage {
|
||||
|
||||
private final Integer totalTokens;
|
||||
|
||||
public VertexAiEmbeddingUsage(Integer totalTokens) {
|
||||
this.totalTokens = totalTokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getPromptTokens() {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getGenerationTokens() {
|
||||
return 0L;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getTotalTokens() {
|
||||
return Long.valueOf(this.totalTokens);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import com.google.protobuf.Value;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.DocumentEmbeddingModel;
|
||||
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
|
||||
@@ -35,6 +36,7 @@ import org.springframework.ai.embedding.EmbeddingResultMetadata;
|
||||
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.MultimodalInstanceBuilder;
|
||||
@@ -230,20 +232,18 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
|
||||
|
||||
String deploymentModelId = embeddingResponse.getDeployedModelId();
|
||||
|
||||
EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), -1);
|
||||
|
||||
responseMetadata.put("deployment-model-id",
|
||||
Map<String, Object> metadataToUse = Map.of("deployment-model-id",
|
||||
StringUtils.hasText(deploymentModelId) ? deploymentModelId : "unknown");
|
||||
|
||||
return new EmbeddingResponse(embeddingList, generateResponseMetadata(mergedOptions.getModel(), 0));
|
||||
EmbeddingResponseMetadata responseMetadata = generateResponseMetadata(mergedOptions.getModel(), 0,
|
||||
metadataToUse);
|
||||
return new EmbeddingResponse(embeddingList, responseMetadata);
|
||||
|
||||
}
|
||||
|
||||
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) {
|
||||
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
|
||||
metadata.put("model", model);
|
||||
metadata.put("total-tokens", tokenCount);
|
||||
return metadata;
|
||||
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens,
|
||||
Map<String, Object> metadataToUse) {
|
||||
Usage usage = new VertexAiEmbeddingUsage(totalTokens);
|
||||
return new EmbeddingResponseMetadata(model, usage, metadataToUse);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -20,6 +20,7 @@ import com.google.cloud.aiplatform.v1.PredictRequest;
|
||||
import com.google.cloud.aiplatform.v1.PredictResponse;
|
||||
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
|
||||
import com.google.protobuf.Value;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
@@ -32,6 +33,7 @@ import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetai
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
|
||||
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@@ -135,10 +137,11 @@ public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
|
||||
}
|
||||
}
|
||||
|
||||
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer tokenCount) {
|
||||
private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
|
||||
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
|
||||
metadata.put("model", model);
|
||||
metadata.put("total-tokens", tokenCount);
|
||||
metadata.setModel(model);
|
||||
Usage usage = new VertexAiEmbeddingUsage(totalTokens);
|
||||
metadata.setUsage(usage);
|
||||
return metadata;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,10 @@ import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
|
||||
@SpringBootTest(classes = VertexAiMultimodelEmbeddingModelIT.Config.class)
|
||||
@SpringBootTest(classes = VertexAiMultimodalEmbeddingModelIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
|
||||
class VertexAiMultimodelEmbeddingModelIT {
|
||||
class VertexAiMultimodalEmbeddingModelIT {
|
||||
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api
|
||||
|
||||
@@ -68,8 +68,13 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
.isEqualTo(embeddingRequest.getInstructions().get(1).getId());
|
||||
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel())
|
||||
.as("Model in metadata should be 'multimodalembedding@001'")
|
||||
.isEqualTo("multimodalembedding@001");
|
||||
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens())
|
||||
.as("Total tokens in metadata should be 0")
|
||||
.isEqualTo("0");
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -90,8 +95,8 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
.isEqualTo(MimeTypeUtils.TEXT_PLAIN);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -113,8 +118,8 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
.isEqualTo(MimeTypeUtils.TEXT_PLAIN);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -139,8 +144,8 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -164,8 +169,8 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
.isEqualTo(new MimeType("video", "mp4"));
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -198,8 +203,8 @@ class VertexAiMultimodelEmbeddingModelIT {
|
||||
.isEqualTo(EmbeddingResultMetadata.ModalityType.VIDEO);
|
||||
assertThat(embeddingResponse.getResults().get(2).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
}
|
||||
@@ -53,8 +53,12 @@ class VertexAiTextEmbeddingModelIT {
|
||||
assertThat(embeddingResponse.getResults()).hasSize(2);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(768);
|
||||
assertThat(embeddingResponse.getResults().get(1).getOutput()).hasSize(768);
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", modelName);
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 5);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).as("Model name in metadata should match expected model")
|
||||
.isEqualTo(modelName);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens())
|
||||
.as("Total tokens in metadata should be 5")
|
||||
.isEqualTo(5L);
|
||||
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(768);
|
||||
}
|
||||
|
||||
@@ -15,37 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.vertexai.gemini;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ChatModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractToolCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
|
||||
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.google.cloud.vertexai.VertexAI;
|
||||
@@ -63,10 +32,39 @@ import com.google.cloud.vertexai.generativeai.PartMaker;
|
||||
import com.google.cloud.vertexai.generativeai.ResponseStream;
|
||||
import com.google.protobuf.Struct;
|
||||
import com.google.protobuf.util.JsonFormat;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ChatModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractToolCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
|
||||
import org.springframework.beans.factory.DisposableBean;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Grogdunn
|
||||
@@ -244,8 +242,8 @@ public class VertexAiGeminiChatModel extends AbstractToolCallSupport<GenerateCon
|
||||
}
|
||||
}
|
||||
|
||||
private VertexAiChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
|
||||
return new VertexAiChatResponseMetadata(new VertexAiUsage(response.getUsageMetadata()));
|
||||
private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
|
||||
return ChatResponseMetadata.builder().withUsage(new VertexAiUsage(response.getUsageMetadata())).build();
|
||||
}
|
||||
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.vertexai.gemini.metadata;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public class VertexAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
|
||||
private final VertexAiUsage usage;
|
||||
|
||||
public VertexAiChatResponseMetadata(VertexAiUsage usage) {
|
||||
this.usage = usage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Usage getUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
import org.springframework.ai.chat.client.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.RequestResponseAdvisor;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.model.Content;
|
||||
@@ -127,15 +128,17 @@ public class QuestionAnswerAdvisor implements RequestResponseAdvisor {
|
||||
|
||||
@Override
|
||||
public ChatResponse adviseResponse(ChatResponse response, Map<String, Object> context) {
|
||||
response.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
|
||||
return response;
|
||||
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(response);
|
||||
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
|
||||
return chatResponseBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> adviseResponse(Flux<ChatResponse> fluxResponse, Map<String, Object> context) {
|
||||
return fluxResponse.map(cr -> {
|
||||
cr.getMetadata().put(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
|
||||
return cr;
|
||||
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(cr);
|
||||
chatResponseBuilder.withMetadata(RETRIEVED_DOCUMENTS, context.get(RETRIEVED_DOCUMENTS));
|
||||
return chatResponseBuilder.build();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -15,40 +15,47 @@
|
||||
*/
|
||||
package org.springframework.ai.chat.metadata;
|
||||
|
||||
import org.springframework.ai.model.AbstractResponseMetadata;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Abstract Data Type (ADT) modeling common AI provider metadata returned in an AI
|
||||
* response.
|
||||
* Models common AI provider metadata returned in an AI response.
|
||||
*
|
||||
* @author John Blum
|
||||
* @author Thomas Vitale
|
||||
* @since 0.7.0
|
||||
* @author Mark Pollack
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface ChatResponseMetadata extends ResponseMetadata {
|
||||
public class ChatResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata {
|
||||
|
||||
class DefaultChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
|
||||
private String id = ""; // Set to blank to preserve backward compat with previous
|
||||
// interface default methods
|
||||
|
||||
}
|
||||
private String model = "";
|
||||
|
||||
ChatResponseMetadata NULL = new DefaultChatResponseMetadata();
|
||||
private RateLimit rateLimit = new EmptyRateLimit();
|
||||
|
||||
private Usage usage = new EmptyUsage();
|
||||
|
||||
private PromptMetadata promptMetadata = PromptMetadata.empty();
|
||||
|
||||
/**
|
||||
* A unique identifier for the chat completion operation.
|
||||
* @return unique operation identifier.
|
||||
*/
|
||||
default String getId() {
|
||||
return "";
|
||||
public String getId() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
/**
|
||||
* The model that handled the request.
|
||||
* @return the model that handled the request.
|
||||
*/
|
||||
default String getModel() {
|
||||
return "";
|
||||
public String getModel() {
|
||||
return this.model;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -56,8 +63,8 @@ public interface ChatResponseMetadata extends ResponseMetadata {
|
||||
* @return AI provider specific metadata on rate limits.
|
||||
* @see RateLimit
|
||||
*/
|
||||
default RateLimit getRateLimit() {
|
||||
return new EmptyRateLimit();
|
||||
public RateLimit getRateLimit() {
|
||||
return this.rateLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -65,12 +72,90 @@ public interface ChatResponseMetadata extends ResponseMetadata {
|
||||
* @return AI provider specific metadata on API usage.
|
||||
* @see Usage
|
||||
*/
|
||||
default Usage getUsage() {
|
||||
return new EmptyUsage();
|
||||
public Usage getUsage() {
|
||||
return this.usage;
|
||||
}
|
||||
|
||||
default PromptMetadata getPromptMetadata() {
|
||||
return PromptMetadata.empty();
|
||||
/**
|
||||
* Returns the prompt metadata gathered by the AI during request processing.
|
||||
* @return the prompt metadata.
|
||||
*/
|
||||
public PromptMetadata getPromptMetadata() {
|
||||
return this.promptMetadata;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private final ChatResponseMetadata chatResponseMetadata;
|
||||
|
||||
public Builder() {
|
||||
this.chatResponseMetadata = new ChatResponseMetadata();
|
||||
}
|
||||
|
||||
public Builder withMetadata(Map<String, Object> mapToCopy) {
|
||||
this.chatResponseMetadata.map.putAll(mapToCopy);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withKeyValue(String key, Object value) {
|
||||
this.chatResponseMetadata.map.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withId(String id) {
|
||||
this.chatResponseMetadata.id = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withModel(String model) {
|
||||
this.chatResponseMetadata.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withRateLimit(RateLimit rateLimit) {
|
||||
this.chatResponseMetadata.rateLimit = rateLimit;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withUsage(Usage usage) {
|
||||
this.chatResponseMetadata.usage = usage;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withPromptMetadata(PromptMetadata promptMetadata) {
|
||||
this.chatResponseMetadata.promptMetadata = promptMetadata;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatResponseMetadata build() {
|
||||
return this.chatResponseMetadata;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
return true;
|
||||
if (!(o instanceof ChatResponseMetadata that))
|
||||
return false;
|
||||
return Objects.equals(id, that.id) && Objects.equals(model, that.model)
|
||||
&& Objects.equals(rateLimit, that.rateLimit) && Objects.equals(usage, that.usage)
|
||||
&& Objects.equals(promptMetadata, that.promptMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(id, model, rateLimit, usage, promptMetadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return AI_METADATA_STRING.formatted(getId(), getUsage(), getRateLimit());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
package org.springframework.ai.chat.model;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import org.springframework.ai.model.ModelResponse;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -40,7 +42,7 @@ public class ChatResponse implements ModelResponse<Generation> {
|
||||
* provider.
|
||||
*/
|
||||
public ChatResponse(List<Generation> generations) {
|
||||
this(generations, ChatResponseMetadata.NULL);
|
||||
this(generations, new ChatResponseMetadata());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -107,4 +109,44 @@ public class ChatResponse implements ModelResponse<Generation> {
|
||||
return Objects.hash(chatResponseMetadata, generations);
|
||||
}
|
||||
|
||||
public static ChatResponse.Builder builder() {
|
||||
return new ChatResponse.Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private List<Generation> generations;
|
||||
|
||||
private ChatResponseMetadata.Builder chatResponseMetadataBuilder;
|
||||
|
||||
private Builder() {
|
||||
this.chatResponseMetadataBuilder = ChatResponseMetadata.builder();
|
||||
}
|
||||
|
||||
public Builder from(ChatResponse other) {
|
||||
this.generations = other.generations;
|
||||
Set<Map.Entry<String, Object>> entries = other.chatResponseMetadata.entrySet();
|
||||
for (Map.Entry<String, Object> entry : entries) {
|
||||
this.chatResponseMetadataBuilder.withKeyValue(entry.getKey(), entry.getValue());
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMetadata(String key, Object value) {
|
||||
this.chatResponseMetadataBuilder.withKeyValue(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withGenerations(List<Generation> generations) {
|
||||
this.generations = generations;
|
||||
return this;
|
||||
|
||||
}
|
||||
|
||||
public ChatResponse build() {
|
||||
return new ChatResponse(generations, chatResponseMetadataBuilder.build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,24 +15,20 @@
|
||||
*/
|
||||
package org.springframework.ai.embedding;
|
||||
|
||||
import java.io.Serial;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.model.AbstractResponseMetadata;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Common AI provider metadata returned in an embedding response.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
public class EmbeddingResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
|
||||
|
||||
@Serial
|
||||
private static final long serialVersionUID = 1L;
|
||||
public class EmbeddingResponseMetadata extends AbstractResponseMetadata implements ResponseMetadata {
|
||||
|
||||
private String model;
|
||||
|
||||
@@ -42,12 +38,15 @@ public class EmbeddingResponseMetadata extends HashMap<String, Object> implement
|
||||
}
|
||||
|
||||
public EmbeddingResponseMetadata(String model, Usage usage) {
|
||||
this.model = model;
|
||||
this.usage = usage;
|
||||
this(model, usage, Map.of());
|
||||
}
|
||||
|
||||
public EmbeddingResponseMetadata(Map<String, ?> metadata) {
|
||||
super(metadata);
|
||||
public EmbeddingResponseMetadata(String model, Usage usage, Map<String, Object> metadata) {
|
||||
this.model = model;
|
||||
this.usage = usage;
|
||||
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
|
||||
this.map.put(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -43,7 +43,7 @@ public class ImageResponse implements ModelResponse<ImageGeneration> {
|
||||
* provider.
|
||||
*/
|
||||
public ImageResponse(List<ImageGeneration> generations) {
|
||||
this(generations, ImageResponseMetadata.NULL);
|
||||
this(generations, new ImageResponseMetadata());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
*/
|
||||
package org.springframework.ai.image;
|
||||
|
||||
import org.springframework.ai.model.MutableResponseMetadata;
|
||||
import org.springframework.ai.model.ResponseMetadata;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -28,16 +29,20 @@ import java.util.HashMap;
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface ImageResponseMetadata extends ResponseMetadata {
|
||||
public class ImageResponseMetadata extends MutableResponseMetadata {
|
||||
|
||||
class DefaultImageResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
private Long created;
|
||||
|
||||
public ImageResponseMetadata() {
|
||||
this.created = System.currentTimeMillis();
|
||||
}
|
||||
|
||||
ImageResponseMetadata NULL = new DefaultImageResponseMetadata();
|
||||
public ImageResponseMetadata(Long created) {
|
||||
this.created = created;
|
||||
}
|
||||
|
||||
default Long getCreated() {
|
||||
return System.currentTimeMillis();
|
||||
public Long getCreated() {
|
||||
return this.created;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package org.springframework.ai.model;
|
||||
|
||||
import io.micrometer.common.lang.NonNull;
|
||||
import io.micrometer.common.lang.Nullable;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class AbstractResponseMetadata {
|
||||
|
||||
protected static final String AI_METADATA_STRING = "{ id: %2$s, usage: %3$s, rateLimit: %4$s }";
|
||||
|
||||
protected final Map<String, Object> map = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Returns {@code null} when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @return entry or {@code null} if not present
|
||||
*/
|
||||
@Nullable
|
||||
public <T> T get(String key) {
|
||||
return (T) this.map.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Throws exception when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @return entry
|
||||
* @throws IllegalArgumentException if not present
|
||||
*/
|
||||
@NonNull
|
||||
public <T> T getRequired(Object key) {
|
||||
T object = (T) this.map.get(key);
|
||||
if (object == null) {
|
||||
throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]");
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if context contains a key.
|
||||
* @param key key
|
||||
* @return {@code true} when the context contains the entry with the given key
|
||||
*/
|
||||
public boolean containsKey(Object key) {
|
||||
return this.map.containsKey(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element or default if not present.
|
||||
* @param key key
|
||||
* @param defaultObject default object to return
|
||||
* @param <T> value type
|
||||
* @return object or default if not present
|
||||
*/
|
||||
public <T> T getOrDefault(Object key, T defaultObject) {
|
||||
return (T) this.map.getOrDefault(key, defaultObject);
|
||||
}
|
||||
|
||||
public Set<Map.Entry<String, Object>> entrySet() {
|
||||
return Collections.unmodifiableMap(this.map).entrySet();
|
||||
}
|
||||
|
||||
public Set<String> keySet() {
|
||||
return Collections.unmodifiableSet(this.map.keySet());
|
||||
}
|
||||
|
||||
public boolean isEmpty() {
|
||||
return this.map.isEmpty();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package org.springframework.ai.model;
|
||||
|
||||
import io.micrometer.common.lang.NonNull;
|
||||
import io.micrometer.common.lang.Nullable;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Function;
|
||||
|
||||
public class MutableResponseMetadata implements ResponseMetadata {
|
||||
|
||||
private final Map<String, Object> map = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* Puts an element to the context.
|
||||
* @param key key
|
||||
* @param object value
|
||||
* @param <T> value type
|
||||
* @return this for chaining
|
||||
*/
|
||||
public <T> MutableResponseMetadata put(String key, T object) {
|
||||
this.map.put(key, object);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Returns {@code null} when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @return entry or {@code null} if not present
|
||||
*/
|
||||
@Override
|
||||
@Nullable
|
||||
public <T> T get(String key) {
|
||||
return (T) this.map.get(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes an entry from the context.
|
||||
* @param key key by which to remove an entry
|
||||
* @return the previous value associated with the key, or null if there was no mapping
|
||||
* for the key
|
||||
*/
|
||||
public Object remove(Object key) {
|
||||
return this.map.remove(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Throws exception when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @throws IllegalArgumentException if not present
|
||||
* @return entry
|
||||
*/
|
||||
@Override
|
||||
@NonNull
|
||||
public <T> T getRequired(Object key) {
|
||||
T object = (T) this.map.get(key);
|
||||
if (object == null) {
|
||||
throw new IllegalArgumentException("Context does not have an entry for key [" + key + "]");
|
||||
}
|
||||
return object;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if context contains a key.
|
||||
* @param key key
|
||||
* @return {@code true} when the context contains the entry with the given key
|
||||
*/
|
||||
@Override
|
||||
public boolean containsKey(Object key) {
|
||||
return this.map.containsKey(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element or default if not present.
|
||||
* @param key key
|
||||
* @param defaultObject default object to return
|
||||
* @param <T> value type
|
||||
* @return object or default if not present
|
||||
*/
|
||||
@Override
|
||||
public <T> T getOrDefault(Object key, T defaultObject) {
|
||||
return (T) this.map.getOrDefault(key, defaultObject);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<Map.Entry<String, Object>> entrySet() {
|
||||
return Collections.unmodifiableMap(this.map).entrySet();
|
||||
}
|
||||
|
||||
public Set<String> keySet() {
|
||||
return Collections.unmodifiableSet(this.map.keySet());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEmpty() {
|
||||
return this.map.isEmpty();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an element or calls a mapping function if entry not present. The function
|
||||
* will insert the value to the map.
|
||||
* @param key key
|
||||
* @param mappingFunction mapping function
|
||||
* @param <T> value type
|
||||
* @return object or one derived from the mapping function if not present
|
||||
*/
|
||||
public <T> T computeIfAbsent(String key, Function<Object, ? extends T> mappingFunction) {
|
||||
return (T) this.map.computeIfAbsent(key, mappingFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the entries from the context.
|
||||
*/
|
||||
public void clear() {
|
||||
this.map.clear();
|
||||
}
|
||||
|
||||
public Map<String, Object> getRawMap() {
|
||||
return map;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -15,18 +15,77 @@
|
||||
*/
|
||||
package org.springframework.ai.model;
|
||||
|
||||
import io.micrometer.common.lang.NonNull;
|
||||
import io.micrometer.common.lang.Nullable;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
* Interface representing metadata associated with an AI model's response. This interface
|
||||
* is designed to provide additional information about the generative response from an AI
|
||||
* model, including processing details and model-specific data. It serves as a value
|
||||
* object within the core domain, enhancing the understanding and management of AI model
|
||||
* responses in various applications.
|
||||
* Interface representing metadata associated with an AI model's response.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
* @since 0.8.0
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface ResponseMetadata extends Map<String, Object> {
|
||||
public interface ResponseMetadata {
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Returns {@code null} when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @return entry or {@code null} if not present
|
||||
*/
|
||||
@Nullable
|
||||
<T> T get(String key);
|
||||
|
||||
/**
|
||||
* Gets an entry from the context. Throws exception when entry is not present.
|
||||
* @param key key
|
||||
* @param <T> value type
|
||||
* @throws IllegalArgumentException if not present
|
||||
* @return entry
|
||||
*/
|
||||
@NonNull
|
||||
<T> T getRequired(Object key);
|
||||
|
||||
/**
|
||||
* Checks if context contains a key.
|
||||
* @param key key
|
||||
* @return {@code true} when the context contains the entry with the given key
|
||||
*/
|
||||
boolean containsKey(Object key);
|
||||
|
||||
/**
|
||||
* Returns an element or default if not present.
|
||||
* @param key key
|
||||
* @param defaultObject default object to return
|
||||
* @param <T> value type
|
||||
* @return object or default if not present
|
||||
*/
|
||||
<T> T getOrDefault(Object key, T defaultObject);
|
||||
|
||||
/**
|
||||
* Returns an element or default if not present.
|
||||
* @param key key
|
||||
* @param defaultObjectSupplier supplier for default object to return
|
||||
* @param <T> value type
|
||||
* @return object or default if not present
|
||||
* @since 1.11.0
|
||||
*/
|
||||
default <T> T getOrDefault(String key, Supplier<T> defaultObjectSupplier) {
|
||||
T value = get(key);
|
||||
return value != null ? value : defaultObjectSupplier.get();
|
||||
}
|
||||
|
||||
Set<Map.Entry<String, Object>> entrySet();
|
||||
|
||||
public Set<String> keySet();
|
||||
|
||||
/**
|
||||
* Returns {@code true} if this map contains no key-value mappings.
|
||||
* @return {@code true} if this map contains no key-value mappings
|
||||
*/
|
||||
boolean isEmpty();
|
||||
|
||||
}
|
||||
|
||||
@@ -136,9 +136,7 @@ public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
|
||||
|
||||
// The chat completion tool call requires the complete conversation
|
||||
// history. Including the initial user message.
|
||||
List<Msg> conversationHistory = new ArrayList<>();
|
||||
|
||||
conversationHistory.addAll(this.doGetUserMessages(request));
|
||||
List<Msg> conversationHistory = new ArrayList<>(this.doGetUserMessages(request));
|
||||
|
||||
Msg responseMessage = this.doGetToolResponseMessage(response);
|
||||
|
||||
@@ -164,9 +162,7 @@ public abstract class AbstractFunctionCallSupport<Msg, Req, Resp> {
|
||||
|
||||
// The chat completion tool call requires the complete conversation
|
||||
// history. Including the initial user message.
|
||||
List<Msg> conversationHistory = new ArrayList<>();
|
||||
|
||||
conversationHistory.addAll(this.doGetUserMessages(request));
|
||||
List<Msg> conversationHistory = new ArrayList<>(this.doGetUserMessages(request));
|
||||
|
||||
Msg responseMessage = this.doGetToolResponseMessage(resp);
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -58,8 +57,7 @@ public class ChatClientResponseEntityTests {
|
||||
@Test
|
||||
public void responseEntityTest() {
|
||||
|
||||
ChatResponseMetadata metadata = new DefaultChatResponseMetadata();
|
||||
metadata.put("key1", "value1");
|
||||
ChatResponseMetadata metadata = ChatResponseMetadata.builder().withKeyValue("key1", "value1").build();
|
||||
|
||||
var chatResponse = new ChatResponse(List.of(new Generation("""
|
||||
{"name":"John", "age":30}
|
||||
@@ -75,7 +73,7 @@ public class ChatClientResponseEntityTests {
|
||||
.responseEntity(MyBean.class);
|
||||
|
||||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
|
||||
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1");
|
||||
assertThat(responseEntity.getResponse().getMetadata().get("key1").toString()).isEqualTo("value1");
|
||||
|
||||
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));
|
||||
|
||||
|
||||
@@ -112,8 +112,8 @@ public class VertexAiTextEmbeddingModelAutoConfigurationIT {
|
||||
.isEqualTo(EmbeddingResultMetadata.ModalityType.TEXT);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1408);
|
||||
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 0);
|
||||
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo("multimodalembedding@001");
|
||||
assertThat(embeddingResponse.getMetadata().getUsage()).isEqualTo(0);
|
||||
|
||||
assertThat(multiModelEmbeddingModel.dimensions()).isEqualTo(1408);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user