From ded9facfe5900f4196da1c23c56b397f1e16afdd Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Mon, 17 Feb 2025 14:41:27 +0000 Subject: [PATCH] Remove deprecations from 1.0.0-M6 - Remove deprecations from models, vector stores and usage - Deprecations from FunctionCallback and ObservationContext/Convention will be in a separate PR Models updates - Remove AbstractToolCallSupport from the models which use ToolCallingManager - Remove deprecated constructors and their usage - Remove FunctionCallbackResolver and FunctionCallbacks usage in the models - Add back deprecations for VectorStoreChatMemoryAdvisor until builder is fixed - Update OpenAiPaymentTransactionIT to use ToolCallbackResolver in config Signed-off-by: Ilayaperumal Gopinathan --- .../ai/anthropic/AnthropicChatModel.java | 153 +------ .../AnthropicChatModelObservationIT.java | 5 +- .../ai/azure/openai/AzureOpenAiChatModel.java | 96 +---- .../converse/BedrockProxyChatModel.java | 177 +-------- .../converse/BedrockConverseChatClientIT.java | 5 +- .../BedrockConverseTestConfiguration.java | 4 +- .../BedrockConverseUsageAggregationTests.java | 20 +- .../converse/BedrockProxyChatModelIT.java | 18 +- .../BedrockProxyChatModelObservationIT.java | 14 +- .../client/BedrockNovaChatClientIT.java | 3 +- .../BedrockConverseChatModelMain.java | 4 +- .../BedrockConverseChatModelMain3.java | 13 +- .../ai/mistralai/MistralAiChatModel.java | 100 +---- .../ai/mistralai/api/MistralAiApi.java | 6 - .../MistralAiChatModelObservationIT.java | 20 +- .../ai/mistralai/MistralAiRetryTests.java | 18 +- .../mistralai/MistralAiTestConfiguration.java | 6 +- .../ai/mistralai/api/MistralAiApiIT.java | 6 +- .../ai/ollama/OllamaChatModel.java | 69 +--- .../ai/ollama/OllamaChatModelTests.java | 8 +- .../ai/openai/OpenAiChatModel.java | 133 +------ .../ai/openai/api/OpenAiApi.java | 91 ----- .../ai/openai/api/OpenAiAudioApi.java | 226 ----------- .../ai/openai/api/OpenAiImageApi.java | 52 --- .../ai/openai/api/OpenAiModerationApi.java | 27 -- .../ai/openai/ChatCompletionRequestTests.java | 32 +- .../ai/openai/OpenAiTestConfiguration.java | 20 +- .../ai/openai/api/OpenAiApiIT.java | 2 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 2 +- .../openai/chat/MessageTypeContentTests.java | 2 +- ...penAiChatModelAdditionalHttpHeadersIT.java | 2 +- .../OpenAiChatModelFunctionCallingIT.java | 4 +- .../chat/OpenAiChatModelNoOpApiKeysIT.java | 2 +- .../chat/OpenAiChatModelObservationIT.java | 7 +- .../chat/OpenAiChatModelProxyToolCallsIT.java | 372 ------------------ .../chat/OpenAiChatModelResponseFormatIT.java | 4 +- ...hatModelWithChatResponseMetadataTests.java | 11 +- .../chat/OpenAiCompatibleChatModelIT.java | 15 +- .../chat/OpenAiPaymentTransactionIT.java | 78 +++- .../ai/openai/chat/OpenAiRetryTests.java | 7 +- .../proxy/DeepSeekWithOpenAiChatModelIT.java | 7 +- .../chat/proxy/GroqWithOpenAiChatModelIT.java | 7 +- .../proxy/NvidiaWithOpenAiChatModelIT.java | 8 +- .../PerplexityWithOpenAiChatModelIT.java | 14 +- .../OpenAiEmbeddingModelObservationIT.java | 2 +- .../transformer/MetadataTransformerIT.java | 5 +- ...iGeminiPaymentTransactionDeprecatedIT.java | 235 ----------- .../ai/chat/metadata/DefaultUsage.java | 16 +- .../ai/chat/metadata/Usage.java | 5 - .../ai/chat/metadata/DefaultUsageTests.java | 38 +- .../tests/tool/ToolCallingManagerTests.java | 12 - .../mistralai/MistralAiChatProperties.java | 2 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 7 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 7 +- .../tool/FunctionCallWithFunctionBeanIT.java | 6 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 5 +- ...alContainerConnectionDetailsFactoryIT.java | 2 +- .../vectorstore/BasicAuthChromaWhereIT.java | 2 +- .../vectorstore/ChromaVectorStoreIT.java | 2 +- .../ChromaVectorStoreObservationIT.java | 2 +- .../TokenSecuredChromaWhereIT.java | 2 +- .../ElasticsearchVectorStoreIT.java | 2 +- ...ElasticsearchVectorStoreObservationIT.java | 2 +- .../hanadb/HanaCloudVectorStoreIT.java | 2 +- .../hanadb/HanaVectorStoreObservationIT.java | 2 +- .../mariadb/MariaDBStoreCustomNamesIT.java | 2 +- .../vectorstore/mariadb/MariaDBStoreIT.java | 2 +- .../mariadb/MariaDBStoreObservationIT.java | 2 +- .../MilvusVectorStoreCustomFieldNamesIT.java | 2 +- .../milvus/MilvusVectorStoreIT.java | 2 +- .../MilvusVectorStoreObservationIT.java | 2 +- .../atlas/MongoDBAtlasVectorStoreIT.java | 2 +- .../MongoDbVectorStoreObservationIT.java | 2 +- .../vectorstore/neo4j/Neo4jVectorStoreIT.java | 2 +- .../neo4j/Neo4jVectorStoreObservationIT.java | 2 +- .../opensearch/OpenSearchVectorStoreIT.java | 2 +- .../OpenSearchVectorStoreObservationIT.java | 2 +- .../pgvector/PgVectorStoreCustomNamesIT.java | 2 +- .../vectorstore/pgvector/PgVectorStoreIT.java | 2 +- .../pgvector/PgVectorStoreObservationIT.java | 2 +- 80 files changed, 296 insertions(+), 1962 deletions(-) delete mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java delete mode 100644 models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 96ffcbaf6..cd02998ac 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -30,13 +30,6 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.tool.LegacyToolCallingManager; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.model.tool.ToolExecutionResult; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.ai.util.json.JsonParser; -import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -59,7 +52,6 @@ import org.springframework.ai.chat.metadata.DefaultUsage; import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.metadata.UsageUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -72,10 +64,12 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.JsonParser; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -94,7 +88,7 @@ import org.springframework.util.StringUtils; * @author Alexandros Pappas * @since 1.0.0 */ -public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel { +public class AnthropicChatModel implements ChatModel { public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getValue(); @@ -135,111 +129,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi) { - this(anthropicApi, - AnthropicChatOptions.builder() - .model(DEFAULT_MODEL_NAME) - .maxTokens(DEFAULT_MAX_TOKENS) - .temperature(DEFAULT_TEMPERATURE) - .build()); - } - - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @param defaultOptions the default options used for the chat completion requests. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) { - this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @param defaultOptions the default options used for the chat completion requests. - * @param retryTemplate the retry template used to retry the Anthropic API calls. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, - RetryTemplate retryTemplate) { - this(anthropicApi, defaultOptions, retryTemplate, null); - } - - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @param defaultOptions the default options used for the chat completion requests. - * @param retryTemplate the retry template used to retry the Anthropic API calls. - * @param functionCallbackResolver the function callback resolver used to resolve the - * function by its name. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, - RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) { - this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of()); - } - - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @param defaultOptions the default options used for the chat completion requests. - * @param retryTemplate the retry template used to retry the Anthropic API calls. - * @param functionCallbackResolver the function callback resolver used to resolve the - * function by its name. - * @param toolFunctionCallbacks the tool function callbacks used to handle the tool - * calls. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, - RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver, - List toolFunctionCallbacks) { - this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, toolFunctionCallbacks, - ObservationRegistry.NOOP); - } - - /** - * Construct a new {@link AnthropicChatModel} instance. - * @param anthropicApi the lower-level API for the Anthropic service. - * @param defaultOptions the default options used for the chat completion requests. - * @param retryTemplate the retry template used to retry the Anthropic API calls. - * @param functionCallbackResolver the function callback resolver used to resolve the - * function by its name. - * @param toolFunctionCallbacks the tool function callbacks used to handle the tool - * calls. - * @deprecated Use {@link AnthropicChatModel.Builder}. - */ - @Deprecated - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, - RetryTemplate retryTemplate, @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry) { - this(anthropicApi, defaultOptions, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build(), - retryTemplate, observationRegistry); - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); - } - public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, AnthropicChatOptions.builder().build(), List.of()); Assert.notNull(anthropicApi, "anthropicApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); @@ -488,10 +380,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, AnthropicChatOptions.class); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - AnthropicChatOptions.class); - } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AnthropicChatOptions.class); @@ -648,10 +536,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; - private FunctionCallbackResolver functionCallbackResolver; - - private List toolCallbacks; - private ToolCallingManager toolCallingManager; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -679,18 +563,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolCallbacks(List toolCallbacks) { - this.toolCallbacks = toolCallbacks; - return this; - } - public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -698,22 +570,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM public AnthropicChatModel build() { if (toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(toolCallbacks, "toolCallbacks cannot be set when toolCallingManager is set"); - return new AnthropicChatModel(anthropicApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry); } - if (functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolCallbacks != null ? this.toolCallbacks : List.of(); - - return new AnthropicChatModel(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, - toolCallbacks, observationRegistry); - } - return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, observationRegistry); } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java index 20a8f037a..ecbfcf6e1 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelObservationIT.java @@ -34,6 +34,8 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.annotation.Autowired; @@ -171,8 +173,7 @@ public class AnthropicChatModelObservationIT { public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi, TestObservationRegistry observationRegistry) { return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(), - RetryTemplate.defaultInstance(), new DefaultFunctionCallbackResolver(), List.of(), - observationRegistry); + ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 0ac363b2a..ff60d3b1b 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -73,7 +73,6 @@ import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.metadata.UsageUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -86,16 +85,12 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.tool.LegacyToolCallingManager; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -120,7 +115,7 @@ import org.springframework.util.CollectionUtils; * @see com.azure.ai.openai.OpenAIClient * @since 1.0.0 */ -public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel { +public class AzureOpenAiChatModel implements ChatModel { private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModel.class); @@ -162,53 +157,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha */ private final ToolCallingManager toolCallingManager; - @Deprecated - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) { - this(openAIClientBuilder, - AzureOpenAiChatOptions.builder() - .deploymentName(DEFAULT_DEPLOYMENT_NAME) - .temperature(DEFAULT_TEMPERATURE) - .build()); - } - - @Deprecated - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) { - this(openAIClientBuilder, options, null); - } - - @Deprecated - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, - FunctionCallbackResolver functionCallbackResolver) { - this(openAIClientBuilder, options, functionCallbackResolver, List.of()); - } - - @Deprecated - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks) { - this(openAIClientBuilder, options, functionCallbackResolver, toolFunctionCallbacks, ObservationRegistry.NOOP); - } - - @Deprecated - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry) { - this(openAIClientBuilder, options, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build(), - observationRegistry); - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the AzureOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); - } - public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, AzureOpenAiChatOptions.builder().build(), List.of()); Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); @@ -534,10 +484,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha if (prompt.getOptions() != null) { AzureOpenAiChatOptions updatedRuntimeOptions; - if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, - FunctionCallingOptions.class, AzureOpenAiChatOptions.class); - } if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, AzureOpenAiChatOptions.class); @@ -668,10 +614,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, AzureOpenAiChatOptions.class); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - AzureOpenAiChatOptions.class); - } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AzureOpenAiChatOptions.class); @@ -978,10 +920,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha private ToolCallingManager toolCallingManager; - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; private Builder() { @@ -1002,18 +940,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -1021,29 +947,9 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha public AzureOpenAiChatModel build() { if (toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(toolFunctionCallbacks, - "toolFunctionCallbacks cannot be set when toolCallingManager is set"); - return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager, observationRegistry); } - - if (functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks - : List.of(); - - return new Builder().openAIClientBuilder(openAIClientBuilder) - .defaultOptions(defaultOptions) - .functionCallbackResolver(functionCallbackResolver) - .toolFunctionCallbacks(toolCallbacks) - .observationRegistry(observationRegistry) - .build(); - } - return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, observationRegistry); } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 3098f4632..50737ca32 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -82,7 +82,6 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -95,10 +94,6 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.tool.LegacyToolCallingManager; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.model.tool.ToolExecutionResult; @@ -135,7 +130,7 @@ import org.springframework.util.StringUtils; * @author Jihoon Kim * @since 1.0.0 */ -public class BedrockProxyChatModel extends AbstractToolCallSupport implements ChatModel { +public class BedrockProxyChatModel implements ChatModel { private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModel.class); @@ -161,33 +156,10 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch */ private ChatModelObservationConvention observationConvention; - /** - * @deprecated Use - * {@link #BedrockProxyChatModel(BedrockRuntimeClient, BedrockRuntimeAsyncClient, ToolCallingChatOptions, ObservationRegistry, ToolCallingManager)} - * instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions, - FunctionCallbackResolver functionCallbackResolver, List toolFunctionCallbacks, - ObservationRegistry observationRegistry) { - - this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, from(defaultOptions), observationRegistry, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build()); - } - public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, FunctionCallingOptions.builder().build(), List.of()); - Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null"); Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); @@ -199,21 +171,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch this.toolCallingManager = toolCallingManager; } - @Deprecated - private static ToolCallingChatOptions from(FunctionCallingOptions options) { - return ToolCallingChatOptions.builder() - .model(options.getModel()) - .maxTokens(options.getMaxTokens()) - .stopSequences(options.getStopSequences()) - .temperature(options.getTemperature()) - .topP(options.getTopP()) - .toolCallbacks(options.getFunctionCallbacks()) - .toolNames(options.getFunctions()) - .internalToolExecutionEnabled(options.getProxyToolCalls() != null ? !options.getProxyToolCalls() : false) - .toolContext(options.getToolContext()) - .build(); - } - private static ToolCallingChatOptions from(ChatOptions options) { return ToolCallingChatOptions.builder() .model(options.getModel()) @@ -295,9 +252,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { runtimeOptions = toolCallingChatOptions.copy(); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = from(functionCallingOptions); - } else { runtimeOptions = from(prompt.getOptions()); } @@ -804,10 +758,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build(); - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; private ChatModelObservationConvention customObservationConvention; @@ -824,160 +774,47 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch return this; } - /** - * @deprecated Use {@link #credentialsProvider(AwsCredentialsProvider)} instead. - */ - @Deprecated - public Builder withCredentialsProvider(AwsCredentialsProvider credentialsProvider) { - Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null."); - this.credentialsProvider = credentialsProvider; - return this; - } - public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) { Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null."); this.credentialsProvider = credentialsProvider; return this; } - /** - * @deprecated Use {@link #region(Region)} instead. - */ - @Deprecated - public Builder withRegion(Region region) { - Assert.notNull(region, "'region' must not be null."); - this.region = region; - return this; - } - public Builder region(Region region) { Assert.notNull(region, "'region' must not be null."); this.region = region; return this; } - /** - * @deprecated Use {@link #timeout(Duration)} instead. - */ - @Deprecated - public Builder withTimeout(Duration timeout) { - Assert.notNull(timeout, "'timeout' must not be null."); - this.timeout = timeout; - return this; - } - public Builder timeout(Duration timeout) { Assert.notNull(timeout, "'timeout' must not be null."); this.timeout = timeout; return this; } - /** - * @deprecated Use {@link #defaultOptions(ToolCallingChatOptions)} instead. - */ - @Deprecated - public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) { - Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); - return this.defaultOptions(ToolCallingChatOptions.builder() - .model(defaultOptions.getModel()) - .maxTokens(defaultOptions.getMaxTokens()) - .stopSequences(defaultOptions.getStopSequences()) - .temperature(defaultOptions.getTemperature()) - .topP(defaultOptions.getTopP()) - .toolCallbacks(defaultOptions.getFunctionCallbacks()) - .toolNames(defaultOptions.getFunctions()) - .internalToolExecutionEnabled( - defaultOptions.getProxyToolCalls() != null ? !defaultOptions.getProxyToolCalls() : false) - .toolContext(defaultOptions.getToolContext()) - .build()); - } - public Builder defaultOptions(ToolCallingChatOptions defaultOptions) { Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); this.defaultOptions = defaultOptions; return this; } - /** - * @deprecated To be removed after M6 - */ - @Deprecated - public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - /** - * @deprecated To be removed after M6 - */ - @Deprecated - public Builder withToolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - - /** - * @deprecated Use {@link #observationRegistry(ObservationRegistry)} instead. - */ - @Deprecated - public Builder withObservationRegistry(ObservationRegistry observationRegistry) { - Assert.notNull(observationRegistry, "'observationRegistry' must not be null."); - this.observationRegistry = observationRegistry; - return this; - } - public Builder observationRegistry(ObservationRegistry observationRegistry) { Assert.notNull(observationRegistry, "'observationRegistry' must not be null."); this.observationRegistry = observationRegistry; return this; } - /** - * @deprecated Use - * {@link #customObservationConvention(ChatModelObservationConvention)} instead. - */ - @Deprecated - public Builder withCustomObservationConvention(ChatModelObservationConvention observationConvention) { - Assert.notNull(observationConvention, "'observationConvention' must not be null."); - this.customObservationConvention = observationConvention; - return this; - } - public Builder customObservationConvention(ChatModelObservationConvention observationConvention) { Assert.notNull(observationConvention, "'observationConvention' must not be null."); this.customObservationConvention = observationConvention; return this; } - /** - * @deprecated Use {@link #bedrockRuntimeClient(BedrockRuntimeClient)} instead. - */ - @Deprecated - public Builder withBedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) { - this.bedrockRuntimeClient = bedrockRuntimeClient; - return this; - } - public Builder bedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) { this.bedrockRuntimeClient = bedrockRuntimeClient; return this; } - /** - * @deprecated Use {@link #bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient)} - * instead. - */ - @Deprecated - public Builder withBedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { - this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; - return this; - } - public Builder bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) { this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient; return this; @@ -1013,23 +850,11 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch BedrockProxyChatModel bedrockProxyChatModel = null; if (this.toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(toolFunctionCallbacks, - "toolFunctionCallbacks cannot be set when toolCallingManager is set"); - bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, this.toolCallingManager); } - else if (this.functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, - this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver, - this.toolFunctionCallbacks, this.observationRegistry); - } else { bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index cfbe4fd1b..7e0149374 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -36,7 +36,6 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; @@ -367,7 +366,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().model(modelName).build()) + .options(ToolCallingChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) .call() @@ -389,7 +388,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to - .options(FunctionCallingOptions.builder().model(modelName).build()) + .options(ToolCallingChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() .content(); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index 7f2fd2c97..a42a4beec 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -21,7 +21,7 @@ import java.time.Duration; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -42,7 +42,7 @@ public class BedrockConverseTestConfiguration { .region(Region.US_EAST_1) // .region(Region.US_EAST_1) .timeout(Duration.ofSeconds(120)) - .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) + .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java index 25a44b63a..09d7c83d6 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java @@ -16,11 +16,7 @@ package org.springframework.ai.bedrock.converse; -import java.util.List; - -import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -40,8 +36,9 @@ import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage; import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.isA; @@ -61,8 +58,10 @@ public class BedrockConverseUsageAggregationTests { @BeforeEach public void beforeEach() { - this.chatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, - FunctionCallingOptions.builder().build(), null, List.of(), ObservationRegistry.NOOP); + this.chatModel = BedrockProxyChatModel.builder() + .bedrockRuntimeClient(this.bedrockRuntimeClient) + .bedrockRuntimeAsyncClient(this.bedrockRuntimeAsyncClient) + .build(); } @Test @@ -138,14 +137,13 @@ public class BedrockConverseUsageAggregationTests { given(this.bedrockRuntimeClient.converse(isA(ConverseRequest.class))).willReturn(converseResponseToolUse) .willReturn(converseResponseFinal); - FunctionCallback functionCallback = FunctionCallback.builder() - .function("getCurrentWeather", (Request request) -> "15.0°C") + ToolCallback toolCallback = FunctionToolCallback.builder("getCurrentWeather", (Request request) -> "15.0°C") .description("Gets the weather in location") .inputType(Request.class) .build(); var result = this.chatModel.call(new Prompt("What is the weather in Paris?", - FunctionCallingOptions.builder().functionCallbacks(functionCallback).build())); + ToolCallingChatOptions.builder().toolCallbacks(toolCallback).build())); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index 3652d322b..0e79c7d08 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -92,7 +92,7 @@ class BedrockProxyChatModelIT { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), - FunctionCallingOptions.builder().model(modelName).build()); + ToolCallingChatOptions.builder().model(modelName).build()); ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0); @@ -128,7 +128,7 @@ class BedrockProxyChatModelIT { @Test void streamingWithTokenUsage() { - var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build(); + var promptOptions = ToolCallingChatOptions.builder().temperature(0.0).build(); var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions); var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage(); @@ -255,9 +255,8 @@ class BedrockProxyChatModelIT { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = FunctionCallingOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + var promptOptions = ToolCallingChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) @@ -305,10 +304,9 @@ class BedrockProxyChatModelIT { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = FunctionCallingOptions.builder() + var promptOptions = ToolCallingChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) @@ -333,7 +331,7 @@ class BedrockProxyChatModelIT { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().model(model).build()) + .options(ToolCallingChatOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() .chatResponse(); @@ -348,7 +346,7 @@ class BedrockProxyChatModelIT { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().model(model).build()) + .options(ToolCallingChatOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .stream() .chatResponse() diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java index aefcf6a72..fb1b9c307 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java @@ -34,7 +34,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.annotation.Autowired; @@ -67,7 +67,7 @@ public class BedrockProxyChatModelObservationIT { @Test void observationForChatOperation() { - var options = FunctionCallingOptions.builder() + var options = ToolCallingChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) @@ -89,7 +89,7 @@ public class BedrockProxyChatModelObservationIT { @Test void observationForStreamingChatOperation() { - var options = FunctionCallingOptions.builder() + var options = ToolCallingChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) @@ -170,10 +170,10 @@ public class BedrockProxyChatModelObservationIT { String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0"; return BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) - .withObservationRegistry(observationRegistry) - .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) + .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .region(Region.US_EAST_1) + .observationRegistry(observationRegistry) + .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index cd90c4368..167b72820 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -33,6 +33,7 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; @@ -185,7 +186,7 @@ public class BedrockNovaChatClientIT { .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) .region(Region.US_EAST_1) .timeout(Duration.ofSeconds(120)) - .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) + .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java index e9008c0ce..7253627b4 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain.java @@ -43,8 +43,8 @@ public final class BedrockConverseChatModelMain { var prompt = new Prompt("Tell me a joke?", ChatOptions.builder().model(modelId).build()); var chatModel = BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) + .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .region(Region.US_EAST_1) .build(); var chatResponse = chatModel.call(prompt); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java index daee808fb..8d408cf0f 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiments/BedrockConverseChatModelMain3.java @@ -24,8 +24,8 @@ import software.amazon.awssdk.regions.Region; import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.bedrock.converse.MockWeatherService; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; /** * Used for reverse engineering the protocol @@ -48,18 +48,17 @@ public final class BedrockConverseChatModelMain3 { // "What's the weather like in San Francisco, Tokyo, and Paris? Return the // temperature in Celsius.", "What's the weather like in Paris? Return the temperature in Celsius.", - FunctionCallingOptions.builder() + ToolCallingChatOptions.builder() .model(modelId) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) .build()); BedrockProxyChatModel chatModel = BedrockProxyChatModel.builder() - .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) - .withRegion(Region.US_EAST_1) + .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) + .region(Region.US_EAST_1) .build(); var response = chatModel.call(prompt); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index 18d8d3b70..33fa05b22 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -89,7 +89,7 @@ import org.springframework.util.MimeType; * @author Alexandros Pappas * @since 1.0.0 */ -public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel { +public class MistralAiChatModel implements ChatModel { private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention(); @@ -121,73 +121,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - /** - * @deprecated Use {@link MistralAiChatModel.Builder}. - */ - @Deprecated - public MistralAiChatModel(MistralAiApi mistralAiApi) { - this(mistralAiApi, - MistralAiChatOptions.builder() - .temperature(0.7) - .topP(1.0) - .safePrompt(false) - .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) - .build()); - } - - /** - * @deprecated Use {@link MistralAiChatModel.Builder}. - */ - @Deprecated - public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) { - this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - /** - * @deprecated Use {@link MistralAiChatModel.Builder}. - */ - @Deprecated - public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable RetryTemplate retryTemplate) { - this(mistralAiApi, options, functionCallbackResolver, List.of(), retryTemplate); - } - - /** - * @deprecated Use {@link MistralAiChatModel.Builder}. - */ - @Deprecated - public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate) { - this(mistralAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, - ObservationRegistry.NOOP); - } - - /** - * @deprecated Use {@link MistralAiChatModel.Builder}. - */ - @Deprecated - public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate, - ObservationRegistry observationRegistry) { - this(mistralAiApi, options, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build(), - retryTemplate, observationRegistry); - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); - } - public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, MistralAiChatOptions.builder().build(), List.of()); Assert.notNull(mistralAiApi, "mistralAiApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); @@ -594,15 +530,11 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM .temperature(0.7) .topP(1.0) .safePrompt(false) - .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + .model(MistralAiApi.ChatModel.SMALL.getValue()) .build(); private ToolCallingManager toolCallingManager; - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -625,18 +557,6 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - public Builder retryTemplate(RetryTemplate retryTemplate) { this.retryTemplate = retryTemplate; return this; @@ -649,25 +569,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM public MistralAiChatModel build() { if (toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(toolFunctionCallbacks, - "toolFunctionCallbacks cannot be set when toolCallingManager is set"); - return new MistralAiChatModel(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry); } - - if (functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks - : List.of(); - - return new MistralAiChatModel(mistralAiApi, defaultOptions, functionCallbackResolver, toolCallbacks, - retryTemplate, observationRegistry); - } - return new MistralAiChatModel(mistralAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, observationRegistry); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index 0be9617d5..4ca18dab2 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -265,12 +265,6 @@ public class MistralAiApi { public enum ChatModel implements ChatModelDescription { // @formatter:off - @Deprecated(forRemoval = true, since = "1.0.0-M6") - OPEN_MISTRAL_7B("open-mistral-7b"), - @Deprecated(forRemoval = true, since = "1.0.0-M6") - OPEN_MIXTRAL_7B("open-mixtral-8x7b"), - @Deprecated(forRemoval = true, since = "1.0.0-M6") - OPEN_MIXTRAL_22B("open-mixtral-8x22b"), // Premier Models CODESTRAL("codestral-latest"), LARGE("mistral-large-latest"), diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java index 702f84874..31070144b 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelObservationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,7 +32,6 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.annotation.Autowired; @@ -70,7 +69,7 @@ public class MistralAiChatModelObservationIT { @Test void observationForChatOperation() { var options = MistralAiChatOptions.builder() - .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + .model(MistralAiApi.ChatModel.SMALL.getValue()) .maxTokens(2048) .stop(List.of("this-is-the-end")) .temperature(0.7) @@ -91,7 +90,7 @@ public class MistralAiChatModelObservationIT { @Test void observationForStreamingChatOperation() { var options = MistralAiChatOptions.builder() - .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + .model(MistralAiApi.ChatModel.SMALL.getValue()) .maxTokens(2048) .stop(List.of("this-is-the-end")) .temperature(0.7) @@ -125,12 +124,12 @@ public class MistralAiChatModelObservationIT { .doesNotHaveAnyRemainingCurrentObservation() .hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME) .that() - .hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + .hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.SMALL.getValue()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.CHAT.value()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.MISTRAL_AI.value()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), - MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) + MistralAiApi.ChatModel.SMALL.getValue()) .hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(), StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel() : KeyValue.NONE_VALUE) @@ -181,9 +180,12 @@ public class MistralAiChatModelObservationIT { @Bean public MistralAiChatModel openAiChatModel(MistralAiApi mistralAiApi, TestObservationRegistry observationRegistry) { - return new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder().build(), - new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(), - observationRegistry); + return MistralAiChatModel.builder() + .mistralAiApi(mistralAiApi) + .defaultOptions(MistralAiChatOptions.builder().build()) + .retryTemplate(RetryTemplate.defaultInstance()) + .observationRegistry(observationRegistry) + .build(); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java index 27495508c..13dd882e3 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiRetryTests.java @@ -77,14 +77,16 @@ public class MistralAiRetryTests { this.retryListener = new TestRetryListener(); this.retryTemplate.registerListener(this.retryListener); - this.chatModel = new MistralAiChatModel(this.mistralAiApi, - MistralAiChatOptions.builder() - .temperature(0.7) - .topP(1.0) - .safePrompt(false) - .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()) - .build(), - null, this.retryTemplate); + this.chatModel = MistralAiChatModel.builder() + .mistralAiApi(this.mistralAiApi) + .defaultOptions(MistralAiChatOptions.builder() + .temperature(0.7) + .topP(1.0) + .safePrompt(false) + .model(MistralAiApi.ChatModel.SMALL.getValue()) + .build()) + .retryTemplate(this.retryTemplate) + .build(); this.embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MetadataMode.EMBED, MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), this.retryTemplate); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java index 11c084ff6..e519df8c6 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiTestConfiguration.java @@ -43,8 +43,10 @@ public class MistralAiTestConfiguration { @Bean public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) { - return new MistralAiChatModel(mistralAiApi, - MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build()); + return MistralAiChatModel.builder() + .mistralAiApi(mistralAiApi) + .defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL.getValue()).build()) + .build(); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java index 48f57e8b9..fcdf12afe 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/api/MistralAiApiIT.java @@ -47,7 +47,7 @@ public class MistralAiApiIT { void chatCompletionEntity() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( - List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); + List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false)); assertThat(response).isNotNull(); assertThat(response.getBody()).isNotNull(); @@ -64,7 +64,7 @@ public class MistralAiApiIT { """, Role.SYSTEM); ResponseEntity response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest( - List.of(systemMessage, userMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false)); + List.of(systemMessage, userMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false)); assertThat(response).isNotNull(); assertThat(response.getBody()).isNotNull(); @@ -74,7 +74,7 @@ public class MistralAiApiIT { void chatCompletionStream() { ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER); Flux response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest( - List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, true)); + List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, true)); assertThat(response).isNotNull(); assertThat(response.collectList().block()).isNotNull(); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index cf823fc97..72a768258 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -28,13 +28,6 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.tool.LegacyToolCallingManager; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.model.tool.ToolExecutionResult; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.ai.util.json.JsonParser; -import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -45,7 +38,6 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatGenerationMetadata; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.DefaultUsage; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -57,9 +49,9 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConven import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaApi.ChatRequest; import org.springframework.ai.ollama.api.OllamaApi.Message.Role; @@ -70,6 +62,8 @@ import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.management.ModelManagementOptions; import org.springframework.ai.ollama.management.OllamaModelManager; import org.springframework.ai.ollama.management.PullModelStrategy; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.util.json.JsonParser; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -89,7 +83,7 @@ import org.springframework.util.StringUtils; * @author Ilayaperumal Gopinathan * @since 1.0.0 */ -public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel { +public class OllamaChatModel implements ChatModel { private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class); @@ -125,24 +119,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - @Deprecated - public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, ObservationRegistry observationRegistry, - ModelManagementOptions modelManagementOptions) { - this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks), - observationRegistry, modelManagementOptions); - - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead."); - } - public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, OllamaOptions.builder().build(), List.of()); Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); @@ -381,10 +359,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, OllamaOptions.class); } - else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) { - runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class, - OllamaOptions.class); - } else { runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, OllamaOptions.class); @@ -540,10 +514,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode private ToolCallingManager toolCallingManager; - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults(); @@ -566,18 +536,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - public Builder observationRegistry(ObservationRegistry observationRegistry) { this.observationRegistry = observationRegistry; return this; @@ -590,24 +548,9 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode public OllamaChatModel build() { if (toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver must not be set when toolCallingManager is set"); - Assert.isNull(toolFunctionCallbacks, - "toolFunctionCallbacks must not be set when toolCallingManager is set"); - return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager, this.observationRegistry, this.modelManagementOptions); } - - if (functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager must not be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks - : List.of(); - return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver, - toolCallbacks, this.observationRegistry, this.modelManagementOptions); - } - return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, this.observationRegistry, this.modelManagementOptions); } diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java index dc5b9bfb7..a2cafcee6 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java @@ -55,9 +55,11 @@ class OllamaChatModelTests { @Test void buildOllamaChatModelWithDeprecatedConstructor() { - ChatModel chatModel = new OllamaChatModel(this.ollamaApi, - OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), null, null, ObservationRegistry.NOOP, - ModelManagementOptions.builder().build()); + ChatModel chatModel = OllamaChatModel.builder() + .ollamaApi(this.ollamaApi) + .defaultOptions(OllamaOptions.builder().model(OllamaModel.MISTRAL).build()) + .observationRegistry(ObservationRegistry.NOOP) + .build(); assertThat(chatModel).isNotNull(); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index c410f31dc..75ee33eb9 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -29,12 +29,6 @@ import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.model.tool.LegacyToolCallingManager; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.model.tool.ToolExecutionResult; -import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.lang.Nullable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -50,7 +44,6 @@ import org.springframework.ai.chat.metadata.EmptyUsage; import org.springframework.ai.chat.metadata.RateLimit; import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.metadata.UsageUtils; -import org.springframework.ai.chat.model.AbstractToolCallSupport; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; @@ -67,6 +60,9 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackResolver; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; @@ -79,6 +75,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.Resource; import org.springframework.http.ResponseEntity; @@ -111,7 +108,7 @@ import org.springframework.util.StringUtils; * @see StreamingChatModel * @see OpenAiApi */ -public class OpenAiChatModel extends AbstractToolCallSupport implements ChatModel { +public class OpenAiChatModel implements ChatModel { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class); @@ -146,96 +143,8 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; - /** - * Creates an instance of the OpenAiChatModel. - * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI - * Chat API. - * @throws IllegalArgumentException if openAiApi is null - * @deprecated Use OpenAiChatModel.Builder. - */ - @Deprecated - public OpenAiChatModel(OpenAiApi openAiApi) { - this(openAiApi, OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build()); - } - - /** - * Initializes an instance of the OpenAiChatModel. - * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI - * Chat API. - * @param options The OpenAiChatOptions to configure the chat model. - * @deprecated Use OpenAiChatModel.Builder. - */ - @Deprecated - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) { - this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); - } - - /** - * Initializes a new instance of the OpenAiChatModel. - * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI - * Chat API. - * @param options The OpenAiChatOptions to configure the chat model. - * @param functionCallbackResolver The function callback resolver. - * @param retryTemplate The retry template. - * @deprecated Use OpenAiChatModel.Builder. - */ - @Deprecated - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) { - this(openAiApi, options, functionCallbackResolver, List.of(), retryTemplate); - } - - /** - * Initializes a new instance of the OpenAiChatModel. - * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI - * Chat API. - * @param options The OpenAiChatOptions to configure the chat model. - * @param functionCallbackResolver The function callback resolver. - * @param toolFunctionCallbacks The tool function callbacks. - * @param retryTemplate The retry template. - * @deprecated Use OpenAiChatModel.Builder. - */ - @Deprecated - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate) { - this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate, - ObservationRegistry.NOOP); - } - - /** - * Initializes a new instance of the OpenAiChatModel. - * @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI - * Chat API. - * @param options The OpenAiChatOptions to configure the chat model. - * @param functionCallbackResolver The function callback resolver. - * @param toolFunctionCallbacks The tool function callbacks. - * @param retryTemplate The retry template. - * @param observationRegistry The ObservationRegistry used for instrumentation. - * @deprecated Use OpenAiChatModel.Builder or OpenAiChatModel(OpenAiApi, - * OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry). - */ - @Deprecated - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options, - @Nullable FunctionCallbackResolver functionCallbackResolver, - @Nullable List toolFunctionCallbacks, RetryTemplate retryTemplate, - ObservationRegistry observationRegistry) { - this(openAiApi, options, - LegacyToolCallingManager.builder() - .functionCallbackResolver(functionCallbackResolver) - .functionCallbacks(toolFunctionCallbacks) - .build(), - retryTemplate, observationRegistry); - logger.warn("This constructor is deprecated and will be removed in the next milestone. " - + "Please use the OpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead."); - } - public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { - // We do not pass the 'defaultOptions' to the AbstractToolSupport, - // because it modifies them. We are using ToolCallingManager instead, - // so we just pass empty options here. - super(null, OpenAiChatOptions.builder().build(), List.of()); Assert.notNull(openAiApi, "openAiApi cannot be null"); Assert.notNull(defaultOptions, "defaultOptions cannot be null"); Assert.notNull(toolCallingManager, "toolCallingManager cannot be null"); @@ -777,10 +686,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode private ToolCallingManager toolCallingManager; - private FunctionCallbackResolver functionCallbackResolver; - - private List toolFunctionCallbacks; - private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -803,18 +708,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode return this; } - @Deprecated - public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) { - this.functionCallbackResolver = functionCallbackResolver; - return this; - } - - @Deprecated - public Builder toolFunctionCallbacks(List toolFunctionCallbacks) { - this.toolFunctionCallbacks = toolFunctionCallbacks; - return this; - } - public Builder retryTemplate(RetryTemplate retryTemplate) { this.retryTemplate = retryTemplate; return this; @@ -827,25 +720,9 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode public OpenAiChatModel build() { if (toolCallingManager != null) { - Assert.isNull(functionCallbackResolver, - "functionCallbackResolver cannot be set when toolCallingManager is set"); - Assert.isNull(toolFunctionCallbacks, - "toolFunctionCallbacks cannot be set when toolCallingManager is set"); - return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate, observationRegistry); } - - if (functionCallbackResolver != null) { - Assert.isNull(toolCallingManager, - "toolCallingManager cannot be set when functionCallbackResolver is set"); - List toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks - : List.of(); - - return new OpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks, - retryTemplate, observationRegistry); - } - return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate, observationRegistry); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 751a63aee..b320c1acc 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -85,97 +85,6 @@ public class OpenAiApi { private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper(); - /** - * Create a new chat completion api with base URL set to https://api.openai.com - * @param apiKey OpenAI apiKey. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String apiKey) { - this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey); - } - - /** - * Create a new chat completion api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String baseUrl, String apiKey) { - this(baseUrl, apiKey, RestClient.builder(), WebClient.builder()); - } - - /** - * Create a new chat completion api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, - WebClient.Builder webClientBuilder) { - this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); - } - - /** - * Create a new chat completion api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, - WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { - this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder, - responseErrorHandler); - } - - /** - * Create a new chat completion api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param completionsPath the path to the chat completions endpoint. - * @param embeddingsPath the path to the embeddings endpoint. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath, - RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { - - this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), completionsPath, embeddingsPath, - restClientBuilder, webClientBuilder, responseErrorHandler); - } - - /** - * Create a new chat completion api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param headers the http headers to use. - * @param completionsPath the path to the chat completions endpoint. - * @param embeddingsPath the path to the embeddings endpoint. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated since 1.0.0.M6 - use {@link #builder()} instead - */ - @Deprecated(since = "1.0.0.M6") - public OpenAiApi(String baseUrl, String apiKey, MultiValueMap headers, String completionsPath, - String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { - this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder, - webClientBuilder, responseErrorHandler); - } - /** * Create a new chat completion api. * @param baseUrl api base URL. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index 33507bb99..6bb2eaf88 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -57,79 +57,6 @@ public class OpenAiAudioApi { private final WebClient webClient; - /** - * Create a new audio api. - * @param openAiToken OpenAI apiKey. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiAudioApi(String openAiToken) { - this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), WebClient.builder(), - RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); - } - - /** - * Create a new audio api. - * @param baseUrl api base URL. - * @param openAiToken OpenAI apiKey. - * @param restClientBuilder RestClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, - ResponseErrorHandler responseErrorHandler) { - Consumer authHeaders; - if (openAiToken != null && !openAiToken.isEmpty()) { - authHeaders = h -> h.setBearerAuth(openAiToken); - } - else { - authHeaders = h -> { - }; - } - - this.restClient = restClientBuilder.baseUrl(baseUrl) - .defaultHeaders(authHeaders) - .defaultStatusHandler(responseErrorHandler) - .build(); - - this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(authHeaders).build(); - } - - /** - * Create a new audio api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiAudioApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, - WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) { - - this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, webClientBuilder, - responseErrorHandler); - } - - /** - * Create a new audio api. - * @param baseUrl api base URL. - * @param apiKey OpenAI apiKey. - * @param headers the http headers to use. - * @param restClientBuilder RestClient builder. - * @param webClientBuilder WebClient builder. - * @param responseErrorHandler Response error handler. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiAudioApi(String baseUrl, String apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, - ResponseErrorHandler responseErrorHandler) { - this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, webClientBuilder, responseErrorHandler); - } - /** * Create a new audio api. * @param baseUrl api base URL. @@ -496,71 +423,26 @@ public class OpenAiAudioApi { private Float speed; - /** - * @deprecated use {@link #model(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withModel(String model) { - this.model = model; - return this; - } - public Builder model(String model) { this.model = model; return this; } - /** - * @deprecated use {@link #input(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withInput(String input) { - this.input = input; - return this; - } - public Builder input(String input) { this.input = input; return this; } - /** - * @deprecated use {@link #voice(Voice)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withVoice(Voice voice) { - this.voice = voice; - return this; - } - public Builder voice(Voice voice) { this.voice = voice; return this; } - /** - * @deprecated use {@link #responseFormat(AudioResponseFormat)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withResponseFormat(AudioResponseFormat responseFormat) { - this.responseFormat = responseFormat; - return this; - } - public Builder responseFormat(AudioResponseFormat responseFormat) { this.responseFormat = responseFormat; return this; } - /** - * @deprecated use {@link #speed(Float)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withSpeed(Float speed) { - this.speed = speed; - return this; - } - public Builder speed(Float speed) { this.speed = speed; return this; @@ -653,99 +535,36 @@ public class OpenAiAudioApi { private GranularityType granularityType; - /** - * @deprecated use {@link #file(byte[])} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withFile(byte[] file) { - this.file = file; - return this; - } - public Builder file(byte[] file) { this.file = file; return this; } - /** - * @deprecated use {@link #model(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withModel(String model) { - this.model = model; - return this; - } - public Builder model(String model) { this.model = model; return this; } - /** - * @deprecated use {@link #language(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withLanguage(String language) { - this.language = language; - return this; - } - public Builder language(String language) { this.language = language; return this; } - /** - * @deprecated use {@link #prompt(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withPrompt(String prompt) { - this.prompt = prompt; - return this; - } - public Builder prompt(String prompt) { this.prompt = prompt; return this; } - /** - * @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withResponseFormat(TranscriptResponseFormat response_format) { - this.responseFormat = response_format; - return this; - } - public Builder responseFormat(TranscriptResponseFormat responseFormat) { this.responseFormat = responseFormat; return this; } - /** - * @deprecated use {@link #temperature(Float)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - public Builder temperature(Float temperature) { this.temperature = temperature; return this; } - /** - * @deprecated use {@link #granularityType(GranularityType)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withGranularityType(GranularityType granularityType) { - this.granularityType = granularityType; - return this; - } - public Builder granularityType(GranularityType granularityType) { this.granularityType = granularityType; return this; @@ -805,71 +624,26 @@ public class OpenAiAudioApi { private Float temperature; - /** - * @deprecated use {@link #file(byte[])} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withFile(byte[] file) { - this.file = file; - return this; - } - public Builder file(byte[] file) { this.file = file; return this; } - /** - * @deprecated use {@link #model(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withModel(String model) { - this.model = model; - return this; - } - public Builder model(String model) { this.model = model; return this; } - /** - * @deprecated use {@link #prompt(String)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withPrompt(String prompt) { - this.prompt = prompt; - return this; - } - public Builder prompt(String prompt) { this.prompt = prompt; return this; } - /** - * @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { - this.responseFormat = responseFormat; - return this; - } - public Builder responseFormat(TranscriptResponseFormat responseFormat) { this.responseFormat = responseFormat; return this; } - /** - * @deprecated use {@link #temperature(Float)} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public Builder withTemperature(Float temperature) { - this.temperature = temperature; - return this; - } - public Builder temperature(Float temperature) { this.temperature = temperature; return this; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index 56661f3e0..fdb3423d2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -17,7 +17,6 @@ package org.springframework.ai.openai.api; import java.util.List; -import java.util.Map; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -30,7 +29,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.client.ResponseErrorHandler; @@ -47,56 +45,6 @@ public class OpenAiImageApi { private final RestClient restClient; - /** - * Create a new OpenAI Image api with base URL set to {@code https://api.openai.com}. - * @param openAiToken OpenAI apiKey. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiImageApi(String openAiToken) { - this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder()); - } - - /** - * Create a new OpenAI Image API with the provided base URL. - * @param baseUrl the base URL for the OpenAI API. - * @param openAiToken OpenAI apiKey. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { - this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); - } - - /** - * Create a new OpenAI Image API with the provided base URL. - * @param baseUrl the base URL for the OpenAI API. - * @param apiKey OpenAI apiKey. - * @param restClientBuilder the rest client builder to use. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder, - ResponseErrorHandler responseErrorHandler) { - this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler); - } - - /** - * Create a new OpenAI Image API with the provided base URL. - * @param baseUrl the base URL for the OpenAI API. - * @param apiKey OpenAI apiKey. - * @param headers the http headers to use. - * @param restClientBuilder the rest client builder to use. - * @param responseErrorHandler the response error handler to use. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiImageApi(String baseUrl, String apiKey, MultiValueMap headers, - RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { - - this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler); - } - /** * Create a new OpenAI Image API with the provided base URL. * @param baseUrl the base URL for the OpenAI API. diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java index 1a509fc6b..ad0a3e962 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java @@ -52,33 +52,6 @@ public class OpenAiModerationApi { private final ObjectMapper objectMapper; - /** - * Create a new OpenAI Moderation api with base URL set to https://api.openai.com - * @param openAiToken OpenAI apiKey. - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiModerationApi(String openAiToken) { - this(DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); - } - - /** - * @deprecated use {@link Builder} instead. - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - public OpenAiModerationApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, - ResponseErrorHandler responseErrorHandler) { - - this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - - this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> { - if (openAiToken != null && !openAiToken.isEmpty()) { - h.setBearerAuth(openAiToken); - } - h.setContentType(MediaType.APPLICATION_JSON); - }).defaultStatusHandler(responseErrorHandler).build(); - } - /** * Create a new OpenAI Moderation API with the provided base URL. * @param baseUrl the base URL for the OpenAI API. diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 561b0fe95..84dc1c567 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -74,8 +74,10 @@ class ChatCompletionRequestTests { @Test void createRequestWithChatOptions() { - var client = new OpenAiChatModel(new OpenAiApi("TEST"), - OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()); + var client = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder().apiKey("TEST").build()) + .defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build()) + .build(); var prompt = client.buildRequestPrompt(new Prompt("Test message content")); @@ -101,8 +103,10 @@ class ChatCompletionRequestTests { void promptOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new OpenAiChatModel(new OpenAiApi("TEST"), - OpenAiChatOptions.builder().model("DEFAULT_MODEL").build()); + var client = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder().apiKey("TEST").build()) + .defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").build()) + .build(); var prompt = client.buildRequestPrompt(new Prompt("Test message content", OpenAiChatOptions.builder() @@ -128,15 +132,17 @@ class ChatCompletionRequestTests { void defaultOptionsTools() { final String TOOL_FUNCTION_NAME = "CurrentWeather"; - var client = new OpenAiChatModel(new OpenAiApi("TEST"), - OpenAiChatOptions.builder() - .model("DEFAULT_MODEL") - .functionCallbacks(List.of(FunctionCallback.builder() - .function(TOOL_FUNCTION_NAME, new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) - .build()); + var client = OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder().apiKey("TEST").build()) + .defaultOptions(OpenAiChatOptions.builder() + .model("DEFAULT_MODEL") + .functionCallbacks(List.of(FunctionCallback.builder() + .function(TOOL_FUNCTION_NAME, new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) + .build()) + .build(); var prompt = client.buildRequestPrompt(new Prompt("Test message content")); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index d669d4ad2..e7401d9d8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -61,28 +61,25 @@ public class OpenAiTestConfiguration { @Bean public OpenAiChatModel openAiChatModel(OpenAiApi api) { - OpenAiChatModel openAiChatModel = new OpenAiChatModel(api, - OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build()); - return openAiChatModel; + return OpenAiChatModel.builder() + .openAiApi(api) + .defaultOptions(OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build()) + .build(); } @Bean public OpenAiAudioTranscriptionModel openAiTranscriptionModel(OpenAiAudioApi api) { - OpenAiAudioTranscriptionModel openAiTranscriptionModel = new OpenAiAudioTranscriptionModel(api); - return openAiTranscriptionModel; + return new OpenAiAudioTranscriptionModel(api); } @Bean public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) { - OpenAiAudioSpeechModel openAiAudioSpeechModel = new OpenAiAudioSpeechModel(api); - return openAiAudioSpeechModel; + return new OpenAiAudioSpeechModel(api); } @Bean public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) { - OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi); - // openAiImageModel.setModel("foobar"); - return openAiImageModel; + return new OpenAiImageModel(imageApi); } @Bean @@ -92,8 +89,7 @@ public class OpenAiTestConfiguration { @Bean public OpenAiModerationModel openAiModerationClient(OpenAiModerationApi openAiModerationApi) { - OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(openAiModerationApi); - return openAiModerationModel; + return new OpenAiModerationModel(openAiModerationApi); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index 71a3e7497..f843c0c73 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -46,7 +46,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") public class OpenAiApiIT { - OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + OpenAiApi openAiApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); @Test void chatCompletionEntity() { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index 33f50b1f4..482fc1157 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -52,7 +52,7 @@ public class OpenAiApiToolFunctionCallIT { MockWeatherService weatherService = new MockWeatherService(); - OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + OpenAiApi completionApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); private static T fromJson(String json, Class targetClass) { try { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java index 65cc583fd..4a48c977f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/MessageTypeContentTests.java @@ -74,7 +74,7 @@ public class MessageTypeContentTests { @BeforeEach public void beforeEach() { - this.chatModel = new OpenAiChatModel(this.openAiApi); + this.chatModel = OpenAiChatModel.builder().openAiApi(this.openAiApi).build(); } @Test diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelAdditionalHttpHeadersIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelAdditionalHttpHeadersIT.java index 07af49fba..b1acd9694 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelAdditionalHttpHeadersIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelAdditionalHttpHeadersIT.java @@ -73,7 +73,7 @@ public class OpenAiChatModelAdditionalHttpHeadersIT { @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi); + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java index 799b76ea1..e62358b4c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelFunctionCallingIT.java @@ -219,12 +219,12 @@ class OpenAiChatModelFunctionCallingIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi); + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelNoOpApiKeysIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelNoOpApiKeysIT.java index 9a73f1105..5caf74c1c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelNoOpApiKeysIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelNoOpApiKeysIT.java @@ -62,7 +62,7 @@ public class OpenAiChatModelNoOpApiKeysIT { @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi); + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java index e6e58499e..33e03e447 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelObservationIT.java @@ -31,6 +31,8 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.openai.OpenAiChatModel; @@ -169,14 +171,13 @@ public class OpenAiChatModelObservationIT { @Bean public OpenAiApi openAiApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); } @Bean public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi, TestObservationRegistry observationRegistry) { return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(), - new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(), - observationRegistry); + ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java deleted file mode 100644 index db047a3cf..000000000 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelProxyToolCallsIT.java +++ /dev/null @@ -1,372 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.openai.chat; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.JsonMappingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.micrometer.observation.ObservationRegistry; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.ToolResponseMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingHelper; -import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.retry.RetryUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.annotation.Bean; -import org.springframework.util.CollectionUtils; - -import static org.assertj.core.api.Assertions.assertThat; - -@SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class) -@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") -class OpenAiChatModelProxyToolCallsIT { - - private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelProxyToolCallsIT.class); - - private static final String DEFAULT_MODEL = "gpt-4o-mini"; - - FunctionCallback functionDefinition = new FunctionCallingHelper.FunctionDefinition("getWeatherInLocation", - "Get the weather in location", """ - { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["C", "F"] - } - }, - "required": ["location", "unit"] - } - """); - - @Autowired - private OpenAiChatModel chatModel; - - // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality - // to help to implement the function call handling logic on the client side. - private FunctionCallingHelper functionCallingHelper = new FunctionCallingHelper(); - - @SuppressWarnings("unchecked") - private static Map getFunctionArguments(String functionArguments) { - try { - return new ObjectMapper().readValue(functionArguments, Map.class); - } - catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - // Function which will be called by the AI model. - private String getWeatherInLocation(String location, String unit) { - - double temperature = 0; - - if (location.contains("Paris")) { - temperature = 15; - } - else if (location.contains("Tokyo")) { - temperature = 10; - } - else if (location.contains("San Francisco")) { - temperature = 30; - } - - return String.format("The weather in %s is %s%s", location, temperature, unit); - } - - @Test - void functionCall() throws JsonMappingException, JsonProcessingException { - - List messages = List - .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - - var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build(); - - var prompt = new Prompt(messages, promptOptions); - - boolean isToolCall = false; - - ChatResponse chatResponse = null; - - do { - - chatResponse = this.chatModel.call(prompt); - - // We will have to convert the chatResponse into OpenAI assistant message. - - // Note that the tool call check could be platform specific because the finish - // reasons. - isToolCall = this.functionCallingHelper.isToolCall(chatResponse, - Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name())); - - if (isToolCall) { - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - assertThat(toolCallGeneration).isNotEmpty(); - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - var functionName = toolCall.name(); - - assertThat(functionName).isEqualTo("getWeatherInLocation"); - - String functionArguments = toolCall.arguments(); - - @SuppressWarnings("unchecked") - Map argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class); - - String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), - argumentsMap.get("unit").toString()); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, - ModelOptionsUtils.toJsonString(functionResponse))); - } - - ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - - List toolCallConversation = this.functionCallingHelper - .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); - - assertThat(toolCallConversation).isNotEmpty(); - - prompt = new Prompt(toolCallConversation, prompt.getOptions()); - } - } - while (isToolCall); - - logger.info("Response: {}", chatResponse); - - assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15"); - } - - @Test - void functionStream() throws JsonMappingException, JsonProcessingException { - - List messages = List - .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - - var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build(); - - var prompt = new Prompt(messages, promptOptions); - - String response = processToolCall(prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> { - - var functionName = toolCall.name(); - - assertThat(functionName).isEqualTo("getWeatherInLocation"); - - String functionArguments = toolCall.arguments(); - - Map argumentsMap = getFunctionArguments(functionArguments); - - String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), - argumentsMap.get("unit").toString()); - - return functionResponse; - }) - .collectList() - .block() - .stream() - .map(cr -> cr.getResult().getOutput().getText()) - .collect(Collectors.joining()); - - logger.info("Response: {}", response); - - assertThat(response).contains("30", "10", "15"); - - } - - private Flux processToolCall(Prompt prompt, Set finishReasons, - Function customFunction) { - - Flux chatResponses = this.chatModel.stream(prompt); - - return chatResponses.flatMap(chatResponse -> { - - boolean isToolCall = this.functionCallingHelper.isToolCall(chatResponse, finishReasons); - - if (isToolCall) { - - Optional toolCallGeneration = chatResponse.getResults() - .stream() - .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) - .findFirst(); - - assertThat(toolCallGeneration).isNotEmpty(); - - AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); - - List toolResponses = new ArrayList<>(); - - for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { - - String functionResponse = customFunction.apply(toolCall); - - toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), - ModelOptionsUtils.toJsonString(functionResponse))); - } - - ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); - - List toolCallConversation = this.functionCallingHelper - .buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse); - - assertThat(toolCallConversation).isNotEmpty(); - - var prompt2 = new Prompt(toolCallConversation, prompt.getOptions()); - - return processToolCall(prompt2, finishReasons, customFunction); - } - - return Flux.just(chatResponse); - }); - } - - @Test - void functionCall2() throws JsonMappingException, JsonProcessingException { - - List messages = List - .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - - var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build(); - - var prompt = new Prompt(messages, promptOptions); - - ChatResponse chatResponse = this.functionCallingHelper.processCall(this.chatModel, prompt, - Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()), - toolCall -> { - - var functionName = toolCall.name(); - - assertThat(functionName).isEqualTo("getWeatherInLocation"); - - String functionArguments = toolCall.arguments(); - - Map argumentsMap = getFunctionArguments(functionArguments); - - String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), - argumentsMap.get("unit").toString()); - - return functionResponse; - }); - - logger.info("Response: {}", chatResponse); - - assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15"); - } - - @Test - void functionStream2() throws JsonMappingException, JsonProcessingException { - - List messages = List - .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); - - var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build(); - - var prompt = new Prompt(messages, promptOptions); - - Flux responses = this.functionCallingHelper.processStream(this.chatModel, prompt, - Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), - OpenAiApi.ChatCompletionFinishReason.STOP.name()), - toolCall -> { - - var functionName = toolCall.name(); - - assertThat(functionName).isEqualTo("getWeatherInLocation"); - - String functionArguments = toolCall.arguments(); - - Map argumentsMap = getFunctionArguments(functionArguments); - - String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), - argumentsMap.get("unit").toString()); - - return functionResponse; - }); - - String response = responses.collectList() - .block() - .stream() - .map(cr -> cr.getResult().getOutput().getText()) - .collect(Collectors.joining()); - - logger.info("Response: {}", response); - - assertThat(response).contains("30", "10", "15"); - - } - - @SpringBootConfiguration - static class Config { - - @Bean - public OpenAiApi chatCompletionApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); - } - - @Bean - public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List toolFunctionCallbacks) { - // enable the proxy tool calls option. - var options = OpenAiChatOptions.builder().model(DEFAULT_MODEL).proxyToolCalls(true).build(); - - return new OpenAiChatModel(openAiApi, options, null, toolFunctionCallbacks, - RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP); - } - - } - -} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java index f2f286496..65a3e2389 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java @@ -234,12 +234,12 @@ public class OpenAiChatModelResponseFormatIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi); + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java index 76dd5a7ee..f606e16ad 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java @@ -18,6 +18,7 @@ package org.springframework.ai.openai.chat; import java.time.Duration; +import org.hamcrest.core.StringContains; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; @@ -133,7 +134,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests { httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358"); httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms"); - this.server.expect(requestTo("/v1/chat/completions")) + this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions"))) .andExpect(method(HttpMethod.POST)) .andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY)) .andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders)); @@ -169,12 +170,16 @@ public class OpenAiChatModelWithChatResponseMetadataTests { @Bean public OpenAiApi chatCompletionApi(RestClient.Builder builder, WebClient.Builder webClientBuilder) { - return new OpenAiApi("", TEST_API_KEY, builder, webClientBuilder); + return OpenAiApi.builder() + .apiKey(TEST_API_KEY) + .restClientBuilder(builder) + .webClientBuilder(webClientBuilder) + .build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi); + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java index 56cd6fa21..6ac14d119 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiCompatibleChatModelIT.java @@ -55,7 +55,10 @@ public class OpenAiCompatibleChatModelIT { static Stream openAiCompatibleApis() { Stream.Builder builder = Stream.builder(); - builder.add(new OpenAiChatModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")), forModelName("gpt-3.5-turbo"))); + builder.add(OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()) + .defaultOptions(forModelName("gpt-3.5-turbo")) + .build()); // (26.01.2025) Disable because the Groq API is down. TODO: Re-enable when the API // is back up. @@ -66,9 +69,13 @@ public class OpenAiCompatibleChatModelIT { // } if (System.getenv("OPEN_ROUTER_API_KEY") != null) { - builder.add(new OpenAiChatModel( - new OpenAiApi("https://openrouter.ai/api", System.getenv("OPEN_ROUTER_API_KEY")), - forModelName("meta-llama/llama-3-8b-instruct"))); + builder.add(OpenAiChatModel.builder() + .openAiApi(OpenAiApi.builder() + .baseUrl("https://openrouter.ai/api") + .apiKey(System.getenv("OPEN_ROUTER_API_KEY")) + .build()) + .defaultOptions(forModelName("meta-llama/llama-3-8b-instruct")) + .build()); } return builder.build(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java index f2a92a1ef..0843df107 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiPaymentTransactionIT.java @@ -16,11 +16,13 @@ package org.springframework.ai.openai.chat; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -34,19 +36,28 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.converter.BeanOutputConverter; -import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallbackResolver; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; +import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; +import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; +import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; +import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Description; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.ParameterizedTypeReference; import static org.assertj.core.api.Assertions.assertThat; @@ -71,7 +82,7 @@ public class OpenAiPaymentTransactionIT { public void transactionPaymentStatuses(String functionName) { List content = this.chatClient.prompt() .advisors(new LoggingAdvisor()) - .functions(functionName) + .tools(functionName) .user(""" What is the status of my payment transactions 001, 002 and 003? """) @@ -102,7 +113,7 @@ public class OpenAiPaymentTransactionIT { Flux flux = this.chatClient.prompt() .advisors(new LoggingAdvisor()) - .functions(functionName) + .tools(functionName) .user(u -> u.text(""" What is the status of my payment transactions 001, 002 and 003? @@ -217,25 +228,56 @@ public class OpenAiPaymentTransactionIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); } @Bean - public OpenAiChatModel openAiClient(OpenAiApi openAiApi, FunctionCallbackResolver functionCallbackResolver) { - return new OpenAiChatModel(openAiApi, - OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build(), - functionCallbackResolver, RetryUtils.DEFAULT_RETRY_TEMPLATE); + public OpenAiChatModel openAiClient(OpenAiApi openAiApi, ToolCallingManager toolCallingManager) { + return OpenAiChatModel.builder() + .openAiApi(openAiApi) + .toolCallingManager(toolCallingManager) + .defaultOptions( + OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build()) + .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) + .build(); } - /** - * Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance must - * different from the other JSON schema types. - */ @Bean - public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) { - DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); - manager.setApplicationContext(context); - return manager; + @ConditionalOnMissingBean + ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, + List functionCallbacks, List tcbProviders) { + + List allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks); + tcbProviders.stream() + .map(pr -> List.of(pr.getToolCallbacks())) + .forEach(allFunctionAndToolCallbacks::addAll); + + var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); + + var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder() + .applicationContext(applicationContext) + .build(); + + return new DelegatingToolCallbackResolver( + List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); + } + + @Bean + @ConditionalOnMissingBean + ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() { + return new DefaultToolExecutionExceptionProcessor(false); + } + + @Bean + @ConditionalOnMissingBean + ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver, + ToolExecutionExceptionProcessor toolExecutionExceptionProcessor, + ObjectProvider observationRegistry) { + return ToolCallingManager.builder() + .observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) + .toolCallbackResolver(toolCallbackResolver) + .toolExecutionExceptionProcessor(toolExecutionExceptionProcessor) + .build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index 20ca78309..2ab1a9517 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -110,8 +110,11 @@ public class OpenAiRetryTests { this.retryListener = new TestRetryListener(); this.retryTemplate.registerListener(this.retryListener); - this.chatModel = new OpenAiChatModel(this.openAiApi, OpenAiChatOptions.builder().build(), null, - this.retryTemplate); + this.chatModel = OpenAiChatModel.builder() + .openAiApi(this.openAiApi) + .defaultOptions(OpenAiChatOptions.builder().build()) + .retryTemplate(this.retryTemplate) + .build(); this.embeddingModel = new OpenAiEmbeddingModel(this.openAiApi, MetadataMode.EMBED, OpenAiEmbeddingOptions.builder().build(), this.retryTemplate); this.audioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi, diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java index 461b8cdaf..975a172d2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java @@ -330,12 +330,15 @@ class DeepSeekWithOpenAiChatModelIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(DEEPSEEK_BASE_URL, System.getenv("DEEPSEEK_API_KEY")); + return OpenAiApi.builder().baseUrl(DEEPSEEK_BASE_URL).apiKey(System.getenv("DEEPSEEK_API_KEY")).build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build()); + return OpenAiChatModel.builder() + .openAiApi(openAiApi) + .defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build()) + .build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java index 94b4c0f07..2e6208398 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/GroqWithOpenAiChatModelIT.java @@ -383,12 +383,15 @@ class GroqWithOpenAiChatModelIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(GROQ_BASE_URL, System.getenv("GROQ_API_KEY")); + return OpenAiApi.builder().baseUrl(GROQ_BASE_URL).apiKey(System.getenv("GROQ_API_KEY")).build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build()); + return OpenAiChatModel.builder() + .openAiApi(openAiApi) + .defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build()) + .build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java index 4983e96c7..88263c923 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -318,13 +318,15 @@ class NvidiaWithOpenAiChatModelIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(NVIDIA_BASE_URL, System.getenv("NVIDIA_API_KEY")); + return OpenAiApi.builder().baseUrl(NVIDIA_BASE_URL).apiKey(System.getenv("NVIDIA_API_KEY")).build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi, - OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build()); + return OpenAiChatModel.builder() + .openAiApi(openAiApi) + .defaultOptions(OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build()) + .build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java index ceb6c9441..6b4e7fbd4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/PerplexityWithOpenAiChatModelIT.java @@ -327,14 +327,20 @@ class PerplexityWithOpenAiChatModelIT { @Bean public OpenAiApi chatCompletionApi() { - return new OpenAiApi(PERPLEXITY_BASE_URL, System.getenv("PERPLEXITY_API_KEY"), PERPLEXITY_COMPLETIONS_PATH, - "/v1/embeddings", RestClient.builder(), WebClient.builder(), - RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + return OpenAiApi.builder() + .baseUrl(PERPLEXITY_BASE_URL) + .apiKey(System.getenv("PERPLEXITY_API_KEY")) + .completionsPath(PERPLEXITY_COMPLETIONS_PATH) + .embeddingsPath("/v1/embeddings") + .build(); } @Bean public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { - return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build()); + return OpenAiChatModel.builder() + .openAiApi(openAiApi) + .defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build()) + .build(); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java index 684d9a92f..aa76a67f7 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/OpenAiEmbeddingModelObservationIT.java @@ -104,7 +104,7 @@ public class OpenAiEmbeddingModelObservationIT { @Bean public OpenAiApi openAiApi() { - return new OpenAiApi(System.getenv("OPENAI_API_KEY")); + return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build(); } @Bean diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java index 66a8d3d8e..5df6f389b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/transformer/MetadataTransformerIT.java @@ -163,13 +163,12 @@ public class MetadataTransformerIT { throw new IllegalArgumentException( "You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY"); } - return new OpenAiApi(apiKey); + return OpenAiApi.builder().apiKey(apiKey).build(); } @Bean public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) { - OpenAiChatModel openAiChatModel = new OpenAiChatModel(openAiApi); - return openAiChatModel; + return OpenAiChatModel.builder().openAiApi(openAiApi).build(); } @Bean diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java deleted file mode 100644 index 96ff923bc..000000000 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiPaymentTransactionDeprecatedIT.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2023-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.vertexai.gemini.tool; - -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.stream.Collectors; - -import com.google.cloud.vertexai.Transport; -import com.google.cloud.vertexai.VertexAI; -import org.junit.jupiter.api.RepeatedTest; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; - -import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; -import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; -import org.springframework.ai.model.function.DefaultFunctionCallbackResolver; -import org.springframework.ai.model.function.FunctionCallback.SchemaType; -import org.springframework.ai.model.function.FunctionCallbackResolver; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; -import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.context.ApplicationContext; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Description; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Christian Tzolov - */ -@SpringBootTest -@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*") -@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*") -@Deprecated -public class VertexAiGeminiPaymentTransactionDeprecatedIT { - - private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionDeprecatedIT.class); - - private static final Map DATASET = Map.of(new Transaction("001"), new Status("pending"), - new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected")); - - @Autowired - ChatClient chatClient; - - @Test - public void paymentStatuses() { - // @formatter:off - String content = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) - .tools("paymentStatus") - .user(""" - What is the status of my payment transactions 001, 002 and 003? - If requred invoke the function per transaction. - """).call().content(); - - logger.info("" + content); - - assertThat(content).contains("001", "002", "003"); - assertThat(content).contains("pending", "approved", "rejected"); - } - - @RepeatedTest(5) - public void streamingPaymentStatuses() { - - Flux streamContent = this.chatClient.prompt() - .advisors(new LoggingAdvisor()) - .tools("paymentStatus") - .user(""" - What is the status of my payment transactions 001, 002 and 003? - If requred invoke the function per transaction. - """) - .stream() - .content(); - - String content = streamContent.collectList().block().stream().collect(Collectors.joining()); - - logger.info(content); - - assertThat(content).contains("001", "002", "003"); - assertThat(content).contains("pending", "approved", "rejected"); - - // Quota rate - try { - Thread.sleep(1000); - } - catch (InterruptedException e) { - } - } - - record TransactionStatusResponse(String id, String status) { - - } - - private static class LoggingAdvisor implements CallAroundAdvisor { - - private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class); - - @Override - public String getName() { - return this.getClass().getSimpleName(); - } - - @Override - public int getOrder() { - return 0; - } - - @Override - public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) { - var response = chain.nextAroundCall(before(advisedRequest)); - observeAfter(response); - return response; - } - - private AdvisedRequest before(AdvisedRequest request) { - logger.info("System text: \n" + request.systemText()); - logger.info("System params: " + request.systemParams()); - logger.info("User text: \n" + request.userText()); - logger.info("User params:" + request.userParams()); - logger.info("Function names: " + request.functionNames()); - - logger.info("Options: " + request.chatOptions().toString()); - - return request; - } - - private void observeAfter(AdvisedResponse advisedResponse) { - logger.info("Response: " + advisedResponse.response()); - } - - } - - record Transaction(String id) { - } - - record Status(String name) { - } - - record Transactions(List transactions) { - } - - record Statuses(List statuses) { - } - - @SpringBootConfiguration - public static class TestConfiguration { - - @Bean - @Description("Get the status of a single payment transaction") - public Function paymentStatus() { - return transaction -> { - logger.info("Single Transaction: " + transaction); - return DATASET.get(transaction); - }; - } - - @Bean - @Description("Get the list statuses of a list of payment transactions") - public Function paymentStatuses() { - return transactions -> { - logger.info("Transactions: " + transactions); - return new Statuses(transactions.transactions().stream().map(t -> DATASET.get(t)).toList()); - }; - } - - @Bean - public ChatClient chatClient(VertexAiGeminiChatModel chatModel) { - return ChatClient.builder(chatModel).build(); - } - - @Bean - public VertexAI vertexAiApi() { - - String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); - String location = System.getenv("VERTEX_AI_GEMINI_LOCATION"); - - return new VertexAI.Builder().setLocation(location) - .setProjectId(projectId) - .setTransport(Transport.REST) - // .setTransport(Transport.GRPC) - .build(); - } - - @Bean - public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ApplicationContext context) { - - FunctionCallbackResolver functionCallbackResolver = springAiFunctionManager(context); - - return new VertexAiGeminiChatModel(vertexAi, - VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) - .temperature(0.1) - .build(), - functionCallbackResolver); - } - - /** - * Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance - * must - * different from the other JSON schema types. - */ - private FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) { - DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver(); - manager.setSchemaType(SchemaType.OPEN_API_SCHEMA); - manager.setApplicationContext(context); - return manager; - } - - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java index 5b083a121..e9a812f44 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/DefaultUsage.java @@ -30,19 +30,13 @@ import com.fasterxml.jackson.annotation.JsonPropertyOrder; * @author Ilayaperumal Gopinathan * @since 1.0.0 */ -@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "generationTokens", "nativeUsage" }) +@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "nativeUsage" }) public class DefaultUsage implements Usage { private final Integer promptTokens; private final Integer completionTokens; - /** - * @deprecated as of 1.0.0-M6, scheduled for removal - */ - @Deprecated(forRemoval = true, since = "1.0.0-M6") - private final Long generationTokens; - private final int totalTokens; private final Object nativeUsage; @@ -62,7 +56,6 @@ public class DefaultUsage implements Usage { public DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens, Object nativeUsage) { this.promptTokens = promptTokens != null ? promptTokens : 0; this.completionTokens = completionTokens != null ? completionTokens : 0; - this.generationTokens = Long.valueOf(this.completionTokens); this.totalTokens = totalTokens != null ? totalTokens : calculateTotalTokens(this.promptTokens, this.completionTokens); this.nativeUsage = nativeUsage; @@ -106,11 +99,8 @@ public class DefaultUsage implements Usage { @JsonCreator public static DefaultUsage fromJson(@JsonProperty("promptTokens") Integer promptTokens, @JsonProperty("completionTokens") Integer completionTokens, - @JsonProperty("generationTokens") Long generationTokens, @JsonProperty("totalTokens") Integer totalTokens, - @JsonProperty("nativeUsage") Object nativeUsage) { - Integer effectiveCompletionTokens = completionTokens != null ? completionTokens - : (generationTokens != null ? generationTokens.intValue() : 0); - return new DefaultUsage(promptTokens, effectiveCompletionTokens, totalTokens, nativeUsage); + @JsonProperty("totalTokens") Integer totalTokens, @JsonProperty("nativeUsage") Object nativeUsage) { + return new DefaultUsage(promptTokens, completionTokens, totalTokens, nativeUsage); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java index fa43712ff..2afe495d7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/metadata/Usage.java @@ -34,11 +34,6 @@ public interface Usage { */ Integer getPromptTokens(); - @Deprecated(forRemoval = true, since = "1.0.0-M6") - default Long getGenerationTokens() { - return getCompletionTokens().longValue(); - } - /** * Returns the number of tokens returned in the {@literal generation (aka completion)} * of the AI's response. diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java index 3a9172117..bf33c530e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/metadata/DefaultUsageTests.java @@ -32,13 +32,12 @@ public class DefaultUsageTests { void testSerializationWithAllFields() throws Exception { DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150)); String json = this.objectMapper.writeValueAsString(usage); - assertThat(json) - .isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}"); + assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}"); } @Test void testDeserializationWithAllFields() throws Exception { - String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}"; + String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}"; DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertThat(usage.getPromptTokens()).isEqualTo(100); assertThat(usage.getCompletionTokens()).isEqualTo(50); @@ -49,8 +48,7 @@ public class DefaultUsageTests { void testSerializationWithNullFields() throws Exception { DefaultUsage usage = new DefaultUsage((Integer) null, (Integer) null, (Integer) null); String json = this.objectMapper.writeValueAsString(usage); - assertThat(json) - .isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}"); + assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}"); } @Test @@ -92,8 +90,7 @@ public class DefaultUsageTests { // Test serialization String json = this.objectMapper.writeValueAsString(usage); - assertThat(json) - .isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}"); + assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}"); // Test deserialization DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); @@ -113,8 +110,7 @@ public class DefaultUsageTests { // Test serialization String json = this.objectMapper.writeValueAsString(usage); - assertThat(json) - .isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}"); + assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}"); // Test deserialization DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class); @@ -123,18 +119,9 @@ public class DefaultUsageTests { assertThat(deserializedUsage.getTotalTokens()).isEqualTo(0); } - @Test - void testDeserializationWithLegacyFormat() throws Exception { - String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}"; - DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); - assertThat(usage.getPromptTokens()).isEqualTo(100); - assertThat(usage.getCompletionTokens()).isEqualTo(50); - assertThat(usage.getTotalTokens()).isEqualTo(150); - } - @Test void testDeserializationWithDifferentPropertyOrder() throws Exception { - String json = "{\"totalTokens\":150,\"generationTokens\":50,\"completionTokens\":50,\"promptTokens\":100}"; + String json = "{\"totalTokens\":150,\"completionTokens\":50,\"promptTokens\":100}"; DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class); assertThat(usage.getPromptTokens()).isEqualTo(100); assertThat(usage.getCompletionTokens()).isEqualTo(50); @@ -150,7 +137,7 @@ public class DefaultUsageTests { DefaultUsage usage = new DefaultUsage(100, 50, 150, customNativeUsage); String json = this.objectMapper.writeValueAsString(usage); assertThat(json).isEqualTo( - "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}"); + "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}"); } @Test @@ -195,14 +182,6 @@ public class DefaultUsageTests { assertThat(deserializedMap.get("field5")).isEqualTo(java.util.Map.of("nested", "value")); } - @Test - @SuppressWarnings("deprecation") - void testDeprecatedGenerationTokens() { - DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150)); - assertThat(usage.getGenerationTokens()).isEqualTo(50L); - assertThat(usage.getCompletionTokens().longValue()).isEqualTo(usage.getGenerationTokens()); - } - @Test void testEqualsAndHashCode() { DefaultUsage usage1 = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150)); @@ -262,8 +241,7 @@ public class DefaultUsageTests { assertThat(usage.getTotalTokens()).isEqualTo(-3); String json = this.objectMapper.writeValueAsString(usage); - assertThat(json) - .isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3,\"generationTokens\":-2}"); + assertThat(json).isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3}"); } @Test diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java index f07591877..50bcfb21c 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/ToolCallingManagerTests.java @@ -73,18 +73,6 @@ public class ToolCallingManagerTests { runExplicitToolCallingExecutionWithOptions(chatOptions, prompt); } - @Test - void explicitToolCallingExecutionWithLegacyOptions() { - ChatOptions chatOptions = FunctionCallingOptions.builder() - .functionCallbacks(ToolCallbacks.from(tools)) - .proxyToolCalls(true) - .build(); - Prompt prompt = new Prompt( - new UserMessage("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")), - chatOptions); - runExplicitToolCallingExecutionWithOptions(chatOptions, prompt); - } - @Test void explicitToolCallingExecutionWithNewOptionsStream() { ChatOptions chatOptions = ToolCallingChatOptions.builder() diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java index 6047f804e..746529d82 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiChatProperties.java @@ -35,7 +35,7 @@ public class MistralAiChatProperties extends MistralAiParentProperties { public static final String CONFIG_PREFIX = "spring.ai.mistralai.chat"; - public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(); + public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.SMALL.getValue(); private static final Double DEFAULT_TEMPERATURE = 0.7; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java index bedb41f5d..04d50fa78 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/minimax/FunctionCallbackWithPlainFunctionBeanIT.java @@ -35,6 +35,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.minimax.MiniMaxChatModel; import org.springframework.ai.minimax.MiniMaxChatOptions; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -97,11 +98,9 @@ class FunctionCallbackWithPlainFunctionBeanIT { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .function("weatherFunction") - .build(); + ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build(); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions)); logger.info("Response: {}", response); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java index a9beae2cf..9700c173b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/moonshot/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.moonshot.MoonshotChatModel; import org.springframework.ai.moonshot.MoonshotChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -98,11 +99,9 @@ class FunctionCallbackWithPlainFunctionBeanIT { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius"); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .function("weatherFunction") - .build(); + ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build(); - ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); + ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions)); logger.info("Response: {}", response); }); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java index eb31e3b88..2bf205545 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vertexai/gemini/tool/FunctionCallWithFunctionBeanIT.java @@ -17,6 +17,7 @@ package org.springframework.ai.autoconfigure.vertexai.gemini.tool; import java.util.List; +import java.util.Set; import java.util.function.Function; import org.junit.jupiter.api.Test; @@ -29,6 +30,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -109,14 +111,14 @@ class FunctionCallWithFunctionBeanIT { """); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().function("weatherFunction").build())); + ToolCallingChatOptions.builder().toolNames("weatherFunction").build())); logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - VertexAiGeminiChatOptions.builder().function("weatherFunction3").build())); + VertexAiGeminiChatOptions.builder().toolNames(Set.of("weatherFunction3")).build())); logger.info("Response: {}", response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 714ea619c..d244ff77e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/zhipuai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.zhipuai.ZhiPuAiChatModel; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; @@ -97,8 +98,8 @@ class FunctionCallbackWithPlainFunctionBeanIT { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .function("weatherFunction") + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder() + .toolNames("weatherFunction") .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java index 232417121..32fce1374 100644 --- a/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java +++ b/spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/mongo/MongoDbAtlasLocalContainerConnectionDetailsFactoryIT.java @@ -100,7 +100,7 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java index 7d44383e1..88ef6a80c 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/BasicAuthChromaWhereIT.java @@ -121,7 +121,7 @@ public class BasicAuthChromaWhereIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java index a34973745..9ab4bc3ac 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreIT.java @@ -279,7 +279,7 @@ public class ChromaVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java index 8f167669a..ef28e2684 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/ChromaVectorStoreObservationIT.java @@ -186,7 +186,7 @@ public class ChromaVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java index bd13afc9a..00997ae4f 100644 --- a/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/chroma/vectorstore/TokenSecuredChromaWhereIT.java @@ -158,7 +158,7 @@ public class TokenSecuredChromaWhereIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java index 34d64a1f6..14d037f6f 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreIT.java @@ -505,7 +505,7 @@ class ElasticsearchVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } @Bean diff --git a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreObservationIT.java b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreObservationIT.java index ac951b1af..5efeeb1dd 100644 --- a/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-elasticsearch-store/src/test/java/org/springframework/ai/vectorstore/elasticsearch/ElasticsearchVectorStoreObservationIT.java @@ -220,7 +220,7 @@ public class ElasticsearchVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } @Bean diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java index 8f0f1d488..59105b516 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaCloudVectorStoreIT.java @@ -128,7 +128,7 @@ public class HanaCloudVectorStoreIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java index fe316738a..bd004a74c 100644 --- a/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-hanadb-store/src/test/java/org/springframework/ai/vectorstore/hanadb/HanaVectorStoreObservationIT.java @@ -206,7 +206,7 @@ public class HanaVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java index dff67b787..6a79ec2a1 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreCustomNamesIT.java @@ -252,7 +252,7 @@ public class MariaDBStoreCustomNamesIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java index 8f5f8de11..d4d1e8edb 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreIT.java @@ -451,7 +451,7 @@ public class MariaDBStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java index aca149114..616078dae 100644 --- a/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java +++ b/vector-stores/spring-ai-mariadb-store/src/test/java/org/springframework/ai/vectorstore/mariadb/MariaDBStoreObservationIT.java @@ -203,7 +203,7 @@ public class MariaDBStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreCustomFieldNamesIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreCustomFieldNamesIT.java index 87de9b841..14a842244 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreCustomFieldNamesIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreCustomFieldNamesIT.java @@ -252,7 +252,7 @@ class MilvusVectorStoreCustomFieldNamesIT { @Bean EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java index 0b02e98ec..50c5a64c4 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java @@ -361,7 +361,7 @@ public class MilvusVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); // return new OpenAiEmbeddingModel(new // OpenAiApi(System.getenv("OPENAI_API_KEY")), MetadataMode.EMBED, // OpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build()); diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreObservationIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreObservationIT.java index 39791d758..af0d22544 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreObservationIT.java @@ -190,7 +190,7 @@ public class MilvusVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStoreIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStoreIT.java index e741f953f..1f61bdf09 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStoreIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDBAtlasVectorStoreIT.java @@ -345,7 +345,7 @@ class MongoDBAtlasVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } @Bean diff --git a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDbVectorStoreObservationIT.java b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDbVectorStoreObservationIT.java index 527328dd2..20bf3db36 100644 --- a/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDbVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-mongodb-atlas-store/src/test/java/org/springframework/ai/vectorstore/mongodb/atlas/MongoDbVectorStoreObservationIT.java @@ -209,7 +209,7 @@ public class MongoDbVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } @Bean diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java index 760ce6293..d4bc5347c 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreIT.java @@ -374,7 +374,7 @@ class Neo4jVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreObservationIT.java b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreObservationIT.java index 2facda316..8ce003d72 100644 --- a/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-neo4j-store/src/test/java/org/springframework/ai/vectorstore/neo4j/Neo4jVectorStoreObservationIT.java @@ -194,7 +194,7 @@ public class Neo4jVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java index b401983a0..380d434c6 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreIT.java @@ -602,7 +602,7 @@ class OpenSearchVectorStoreIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreObservationIT.java b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreObservationIT.java index 0bab8f183..13bf2eea1 100644 --- a/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-opensearch-store/src/test/java/org/springframework/ai/vectorstore/opensearch/OpenSearchVectorStoreObservationIT.java @@ -226,7 +226,7 @@ public class OpenSearchVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreCustomNamesIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreCustomNamesIT.java index 1c839572f..568447438 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreCustomNamesIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreCustomNamesIT.java @@ -245,7 +245,7 @@ public class PgVectorStoreCustomNamesIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java index 20d0a9818..e0b2bc357 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java @@ -483,7 +483,7 @@ public class PgVectorStoreIT extends BaseVectorStoreTests { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreObservationIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreObservationIT.java index 6fe478010..88cbe5752 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreObservationIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreObservationIT.java @@ -214,7 +214,7 @@ public class PgVectorStoreObservationIT { @Bean public EmbeddingModel embeddingModel() { - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build()); } }