From f36b73cce2827688548cf8c436f742881c4c8194 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Sun, 26 Jan 2025 16:39:27 +0100 Subject: [PATCH] refactor: migrate from functions to tools terminology Refactors the test codebase to use tools instead of functions. - Rename FunctionCallback to FunctionToolCallback - Rename FunctionCallingOptions to ToolCallingChatOptions - Update API methods from functions() to tools() - Deprecate function-related methods in favor of tool alternatives - Refactor MethodToolCallback implementation with improved builder pattern - Update all tests to use new tool-based APIs - Add funcs to tools migration guide Signed-off-by: Christian Tzolov --- FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md | 234 ++++++++++++++++++ .../ai/anthropic/AnthropicChatModelIT.java | 15 +- .../client/AnthropicChatClientIT.java | 14 +- ...lientMethodInvokingFunctionCallbackIT.java | 130 +++++++--- .../AzureOpenAiChatModelFunctionCallIT.java | 17 +- .../converse/BedrockConverseChatClientIT.java | 26 +- .../client/BedrockNovaChatClientIT.java | 9 +- .../ai/mistralai/MistralAiChatClientIT.java | 11 +- .../ai/mistralai/MistralAiChatModelIT.java | 13 +- .../OllamaChatModelFunctionCallingIT.java | 8 +- .../chat/client/OpenAiChatClientIT.java | 19 +- ...lientMethodInvokingFunctionCallbackIT.java | 82 ++++-- ...enAiChatClientMultipleFunctionCallsIT.java | 20 +- .../OpenAiChatClientProxyFunctionCallsIT.java | 2 +- ...texAiGeminiChatModelFunctionCallingIT.java | 67 +---- .../VertexAiGeminiPaymentTransactionIT.java | 5 +- .../ai/chat/client/ChatClient.java | 15 +- .../ai/chat/client/DefaultChatClient.java | 32 ++- .../chat/client/DefaultChatClientBuilder.java | 10 +- .../ai/tool/definition/ToolDefinition.java | 14 +- .../ai/tool/method/MethodToolCallback.java | 8 +- .../chat/client/DefaultChatClientTests.java | 6 +- .../MethodToolCallbackProviderTests.java | 14 +- .../tests/tool/FunctionToolCallbackTests.java | 10 +- .../tests/tool/MethodToolCallbackTests.java | 8 +- .../FunctionCallWithPromptFunctionIT.java | 12 +- .../tool/FunctionCallWithFunctionBeanIT.java | 4 +- .../FunctionCallWithFunctionWrapperIT.java | 8 +- .../FunctionCallWithPromptFunctionIT.java | 12 +- .../tool/FunctionCallWithFunctionBeanIT.java | 8 +- .../FunctionCallWithPromptFunctionIT.java | 16 +- .../mistralai/tool/PaymentStatusPromptIT.java | 6 +- .../tool/WeatherServicePromptIT.java | 19 +- .../tool/FunctionCallbackInPromptIT.java | 10 +- .../ollama/tool/OllamaFunctionCallbackIT.java | 12 +- .../tool/FunctionCallbackInPrompt2IT.java | 26 +- .../tool/FunctionCallbackInPromptIT.java | 22 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 10 +- .../tool/OpenAiFunctionCallback2IT.java | 8 +- .../openai/tool/OpenAiFunctionCallbackIT.java | 8 +- 40 files changed, 624 insertions(+), 346 deletions(-) create mode 100644 FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md diff --git a/FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md b/FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md new file mode 100644 index 000000000..0c1019461 --- /dev/null +++ b/FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md @@ -0,0 +1,234 @@ +# Migrating from FunctionCallback to ToolCallback API + +This guide helps you migrate from the deprecated FunctionCallback API to the new ToolCallback API in Spring AI. + +## Overview of Changes + +The Spring AI project is moving from "functions" to "tools" terminology to better align with industry standards. This involves several API changes while maintaining backward compatibility through deprecated methods. + +## Key Changes + +1. `FunctionCallback` → `ToolCallback` +2. `FunctionCallback.builder().functions()` → `FunctionToolCallback.builder()` +3. `FunctionCallback.builder().method()` → `MethodToolCallback.builder()` +4. `FunctionCallingOptions` → `ToolCallingChatOptions` +5. Method names from `functions()` → `tools()` + +## Migration Examples + +### 1. Basic Function Callback + +Before: +```java +FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build() +``` + +After: +```java +FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build() +``` + +### 2. ChatClient Usage + +Before: +```java +String response = ChatClient.create(chatModel) + .prompt() + .user("What's the weather like in San Francisco?") + .functions(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .call() + .content(); +``` + +After: +```java +String response = ChatClient.create(chatModel) + .prompt() + .user("What's the weather like in San Francisco?") + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .call() + .content(); +``` + +### 3. Method-Based Function Callbacks + +Before: +```java +FunctionCallback.builder() + .method("getWeatherInLocation", String.class, Unit.class) + .description("Get the weather in location") + .targetClass(TestFunctionClass.class) + .build() +``` + +After: +```java +var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherInLocation", String.class, Unit.class); + +MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .build() +``` + +And you can use the same `ChatClient#tools()` API to register method-based tool callbackes: + +```java +String response = ChatClient.create(chatModel) + .prompt() + .user("What's the weather like in San Francisco?") + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .build()) + .call() + .content(); +``` + +### 4. Options Configuration + +Before: +```java +FunctionCallingOptions.builder() + .model(modelName) + .function("weatherFunction") + .build() +``` + +After: +```java +ToolCallingChatOptions.builder() + .model(modelName) + .tools("weatherFunction") + .build() +``` + +### 5. Default Functions in ChatClient Builder + +Before: +```java +ChatClient.builder(chatModel) + .defaultFunctions(FunctionCallback.builder() + .function("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .build() +``` + +After: +```java +ChatClient.builder(chatModel) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) + .build() +``` + +### 6. Spring Bean Configuration + +Before: +```java +@Bean +public FunctionCallback weatherFunctionInfo() { + return FunctionCallback.builder() + .function("WeatherInfo", new MockWeatherService()) + .description("Get the current weather") + .inputType(MockWeatherService.Request.class) + .build(); +} +``` + +After: +```java +@Bean +public ToolCallback weatherFunctionInfo() { + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) + .description("Get the current weather") + .inputType(MockWeatherService.Request.class) + .build(); +} +``` + +## Breaking Changes + +1. The `method()` configuration in function callbacks has been replaced with a more explicit method tool configuration using `ToolDefinition` and `MethodToolCallback`. + +2. When using method-based callbacks, you now need to explicitly find the method using `ReflectionUtils` and provide it to the builder. + +3. For non-static methods, you must now provide both the method and the target object: +```java +MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Description") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) + .build() +``` + +## Deprecated Methods + +The following methods are deprecated and will be removed in a future release: + +- `ChatClient.Builder.defaultFunctions(String...)` +- `ChatClient.Builder.defaultFunctions(FunctionCallback...)` +- `ChatClient.RequestSpec.functions()` + +Use their `tools` counterparts instead. + +## @Tool tool definition path. + +Now you can use the method-level annothation (`@Tool`) to register tools with Spring AI + +```java +public class Home { + + @Tool(description = "Turn light On or Off in a room.") + public void turnLight(String roomName, boolean on) { + // ... + logger.info("Turn light in room: {} to: {}", roomName, on); + } +} + +Home homeAutomation = new HomeAutomation(); + +String response = ChatClient.create(this.chatModel).prompt() + .user("Turn the light in the living room On.") + .tools(homeAutomation) + .call() + .content(); + +``` + + +## Additional Notes + +1. The new API provides better separation between tool definition and implementation. +2. Tool definitions can be reused across different implementations. +3. The builder pattern has been simplified for common use cases. +4. Better support for method-based tools with improved error handling. + +## Timeline + +The deprecated methods will be maintained for backward compatibility in the current major version but will be removed in the next major release. It's recommended to migrate to the new API as soon as possible. diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index bda5288be..dca4c7ddb 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -47,8 +47,8 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.model.Media; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.SpringBootConfiguration; @@ -258,7 +258,7 @@ class AnthropicChatModelIT { List.of(new Media(new MimeType("application", "pdf"), pdfData))); var response = this.chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build())); + ToolCallingChatOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build())); assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API"); } @@ -273,8 +273,7 @@ class AnthropicChatModelIT { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) @@ -306,8 +305,7 @@ class AnthropicChatModelIT { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) @@ -337,8 +335,7 @@ class AnthropicChatModelIT { var promptOptions = AnthropicChatOptions.builder() .model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") .inputType(MockWeatherService.Request.class) diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index 7d65d2396..f03a24ccc 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -40,7 +40,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -211,8 +211,7 @@ class AnthropicChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .inputType(MockWeatherService.Request.class) .build()) .call() @@ -230,8 +229,7 @@ class AnthropicChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeatherInLocation", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeatherInLocation", new MockWeatherService()) .inputType(MockWeatherService.Request.class) .build()) .call() @@ -248,8 +246,7 @@ class AnthropicChatClientIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -271,8 +268,7 @@ class AnthropicChatClientIT { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java index 751efab3c..191a2679e 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java @@ -29,11 +29,14 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.log.LogAccessor; import org.springframework.test.context.ActiveProfiles; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -41,6 +44,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429") @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") @ActiveProfiles("logging-test") +@SuppressWarnings("null") class AnthropicChatClientMethodInvokingFunctionCallbackIT { private static final LogAccessor logger = new LogAccessor( @@ -57,11 +61,14 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { void methodGetWeatherGeneratedDescription() { // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherInLocation", String.class, Unit.class); + String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherInLocation", String.class, Unit.class) - .targetClass(TestFunctionClass.class) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod).build()) + .toolMethod(toolMethod) .build()) .call() .content(); @@ -76,12 +83,16 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { void methodGetWeatherStatic() { // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherStatic", String.class, Unit.class); + String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetClass(TestFunctionClass.class) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) .build()) .call() .content(); @@ -98,12 +109,18 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); // @formatter:off + + var turnLightMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "turnLight", String.class, boolean.class); + String response = ChatClient.create(this.chatModel).prompt() .user("Turn light on in the living room.") - .functions(FunctionCallback.builder() - .method("turnLight", String.class, boolean.class) - .description("Turn light on in the living room.") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(turnLightMethod) + .description("Turn light on in the living room.") + .build()) + .toolMethod(turnLightMethod) + .toolObject(targetObject) .build()) .call() .content(); @@ -121,12 +138,17 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherNonStatic", String.class, Unit.class); + String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherNonStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .call() .content(); @@ -143,17 +165,21 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherWithContext", String.class, Unit.class, ToolContext.class); + String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherWithContext", String.class, Unit.class, ToolContext.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .toolContext(Map.of("tool", "value")) .call() .content(); - // @formatter:on logger.info("Response: " + response); @@ -170,18 +196,23 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "getWeatherNonStatic", String.class, Unit.class); + assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherNonStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .toolContext(Map.of("tool", "value")) .call() .content()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Configured method does not accept ToolContext as input parameter!"); + .hasMessage("ToolContext is not supported by the method as an argument"); // @formatter:on } @@ -190,13 +221,18 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); - // @formatter:off + // @formatter:off + var toolMethod = ReflectionUtils.findMethod( + TestFunctionClass.class, "turnLivingRoomLightOn"); + String response = ChatClient.create(this.chatModel).prompt() .user("Turn light on in the living room.") - .functions(FunctionCallback.builder() - .method("turnLivingRoomLightOn") - .description("Can turn lights on in the Living Room") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolMethod(toolMethod) + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Can turn lights on in the Living Room") + .build()) + .toolObject(targetObject) .build()) .call() .content(); @@ -207,6 +243,25 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { assertThat(arguments).containsEntry("turnLivingRoomLightOn", true); } + @Test + void toolAnnotation() { + + TestFunctionClass targetObject = new TestFunctionClass(); + + // @formatter:off + String response = ChatClient.create(this.chatModel).prompt() + .user("Turn light red in the living room.") + .tools(targetObject) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(arguments).containsEntry("roomName", "living room") + .containsEntry("color", TestFunctionClass.LightColor.RED); + } + @Autowired ChatModel chatModel; @@ -270,6 +325,19 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT { arguments.put("turnLivingRoomLightOn", true); } + enum LightColor { + + RED, GREEN, BLUE + + } + + @Tool(description = "Change the lamp color in a room.") + public void changeRoomLightColor(String roomName, LightColor color) { + arguments.put("roomName", roomName); + arguments.put("color", color); + logger.info("Change light colur in room: {} to color: {}", roomName, color); + } + } } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java index 680a30832..e25b7a253 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatModelFunctionCallIT.java @@ -38,7 +38,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -69,8 +69,7 @@ class AzureOpenAiChatModelFunctionCallIT { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -98,8 +97,7 @@ class AzureOpenAiChatModelFunctionCallIT { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -120,8 +118,7 @@ class AzureOpenAiChatModelFunctionCallIT { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -158,8 +155,7 @@ class AzureOpenAiChatModelFunctionCallIT { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) @@ -185,8 +181,7 @@ class AzureOpenAiChatModelFunctionCallIT { var promptOptions = AzureOpenAiChatOptions.builder() .deploymentName(this.selectedModel) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 99d9ac2e2..93edb429e 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -34,8 +34,8 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -210,8 +210,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -230,8 +229,7 @@ class BedrockConverseChatClientIT { // @formatter:off ChatResponse response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -265,8 +263,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel) .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -285,8 +282,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -308,8 +304,7 @@ class BedrockConverseChatClientIT { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -350,8 +345,7 @@ class BedrockConverseChatClientIT { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in Paris? Return the temperature in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -371,7 +365,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() - .options(FunctionCallingOptions.builder().model(modelName).build()) + .options(ToolCallingChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) .call() @@ -393,7 +387,7 @@ class BedrockConverseChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to - .options(FunctionCallingOptions.builder().model(modelName).build()) + .options(ToolCallingChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() .content(); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index b62a87df4..b3d6e0bed 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -29,8 +29,8 @@ import org.springframework.ai.bedrock.converse.RequiresAwsCredentials; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.model.Media; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -145,8 +145,7 @@ public class BedrockNovaChatClientIT { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", (WeatherRequest request) -> { + .tools(FunctionToolCallback.builder("getCurrentWeather", (WeatherRequest request) -> { if (request.location().contains("Paris")) { return new WeatherResponse(15, request.unit()); } @@ -183,7 +182,7 @@ public class BedrockNovaChatClientIT { .withCredentialsProvider(EnvironmentVariableCredentialsProvider.create()) .withRegion(Region.US_EAST_1) .withTimeout(Duration.ofSeconds(120)) - .withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build()) + .withDefaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java index 5ee425793..fba261f81 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatClientIT.java @@ -32,7 +32,7 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -224,8 +224,7 @@ class MistralAiChatClientIT { String response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).toolChoice(ToolChoice.AUTO).build()) .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")) - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -246,8 +245,7 @@ class MistralAiChatClientIT { // @formatter:off String response = ChatClient.builder(this.chatModel) .defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build()) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -270,8 +268,7 @@ class MistralAiChatClientIT { Flux response = ChatClient.create(this.chatModel).prompt() .options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build()) .user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java index 251ae548e..03741aabd 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatModelIT.java @@ -47,6 +47,7 @@ import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; @@ -97,7 +98,8 @@ class MistralAiChatModelIT { "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."); SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); - // NOTE: Mistral expects the system message to be before the user message or will + // NOTE: Mistral expects the system message to be before the user message or + // will // fail with 400 error. Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); ChatResponse response = this.chatModel.call(prompt); @@ -202,8 +204,7 @@ class MistralAiChatModelIT { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -228,8 +229,7 @@ class MistralAiChatModelIT { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) @@ -316,8 +316,7 @@ class MistralAiChatModelIT { var promptOptions = MistralAiChatOptions.builder() .model(MistralAiApi.ChatModel.SMALL.getValue()) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java index 3d8353aec..04a8cd5b8 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java @@ -30,10 +30,10 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.tool.MockWeatherService; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -61,8 +61,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { var promptOptions = OllamaOptions.builder() .model(MODEL) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) @@ -85,8 +84,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT { var promptOptions = OllamaOptions.builder() .model(MODEL) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java index cf152c687..60434c740 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientIT.java @@ -35,13 +35,13 @@ import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.ParameterizedTypeReference; @@ -245,16 +245,13 @@ class OpenAiChatClientIT extends AbstractIT { @Test void functionCallTest() { - FunctionCallback functionCallback = FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build(); - // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .functions(functionCallback) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) .call() .content(); // @formatter:on @@ -269,8 +266,7 @@ class OpenAiChatClientIT extends AbstractIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -290,8 +286,7 @@ class OpenAiChatClientIT extends AbstractIT { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java index ec78c50c3..112133a90 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java @@ -26,12 +26,14 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.log.LogAccessor; import org.springframework.test.context.ActiveProfiles; +import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -55,13 +57,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { @Test void methodGetWeatherStatic() { + + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherStatic", String.class, + Unit.class); + // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetClass(TestFunctionClass.class) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) .build()) .call() .content(); @@ -77,13 +84,17 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLight", String.class, boolean.class); + // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("Turn light on in the living room.") - .functions(FunctionCallback.builder() - .method("turnLight", String.class, boolean.class) - .description("Can turn lights on or off by room name") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Can turn lights on or off by room name") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .call() .content(); @@ -100,13 +111,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class, + Unit.class); + // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherNonStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .call() .content(); @@ -122,13 +138,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class, + Unit.class, ToolContext.class); + // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherWithContext", String.class, Unit.class, ToolContext.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .toolContext(Map.of("tool", "value")) .call() @@ -146,19 +167,24 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class, + Unit.class); + // @formatter:off assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") - .functions(FunctionCallback.builder() - .method("getWeatherNonStatic", String.class, Unit.class) - .description("Get the weather in location") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Get the weather in location") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .toolContext(Map.of("tool", "value")) .call() .content()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Configured method does not accept ToolContext as input parameter!"); + .hasMessage("ToolContext is not supported by the method as an argument"); // @formatter:on } @@ -167,13 +193,17 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT { TestFunctionClass targetObject = new TestFunctionClass(); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLivingRoomLightOn"); + // @formatter:off String response = ChatClient.create(this.chatModel).prompt() .user("Turn light on in the living room.") - .functions(FunctionCallback.builder() - .method("turnLivingRoomLightOn") - .description("Can turn lights on in the Living Room") - .targetObject(targetObject) + .tools(MethodToolCallback.builder() + .toolDefinition(ToolDefinition.builder(toolMethod) + .description("Can turn lights on in the Living Room") + .build()) + .toolMethod(toolMethod) + .toolObject(targetObject) .build()) .call() .content(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java index e4f9b8502..715213a7e 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMultipleFunctionCallsIT.java @@ -29,12 +29,12 @@ import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiTestConfiguration; import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.api.tool.MockWeatherService.Request; import org.springframework.ai.openai.api.tool.MockWeatherService.Response; import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.core.io.Resource; @@ -83,8 +83,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { // @formatter:off response = chatClientBuilder.build().prompt() .user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?")) - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -114,8 +113,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -157,8 +155,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", biFunction) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -201,8 +198,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { // @formatter:off String response = ChatClient.builder(this.chatModel) - .defaultFunctions(FunctionCallback.builder() - .function("getCurrentWeather", biFunction) + .defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -224,8 +220,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { // @formatter:off Flux response = ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) + .tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -254,8 +249,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT { String content = chatClient.prompt() .user("What's the weather like in Shanghai?") - .functions(FunctionCallback.builder() - .function("currentTemp", function) + .tools(FunctionToolCallback.builder("currentTemp", function) .description("get current temp") .inputType(MyFunction.Req.class) .build()) diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java index a173eb755..67ea94ece 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java @@ -122,7 +122,7 @@ class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT { chatResponse = chatClient.prompt() .messages(messages) - .functions(this.functionDefinition) + .tools(this.functionDefinition) .options(OpenAiChatOptions.builder().proxyToolCalls(true).build()) .call() .chatResponse(); diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java index de41a33b5..0e61e6d85 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatModelFunctionCallingIT.java @@ -33,8 +33,8 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.util.json.JsonSchemaGenerator; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions; import org.springframework.beans.factory.annotation.Autowired; @@ -82,10 +82,9 @@ public class VertexAiGeminiChatModelFunctionCallingIT { """; var promptOptions = VertexAiGeminiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("get_current_weather", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) .description("Get the current weather in a given location") - .inputTypeSchema(openApiSchema) + .inputSchema(openApiSchema) .inputType(MockWeatherService.Request.class) .build())) .build(); @@ -100,46 +99,6 @@ public class VertexAiGeminiChatModelFunctionCallingIT { @Test public void functionCallTestInferredOpenApiSchema() { - UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius units."); - - List messages = new ArrayList<>(List.of(userMessage)); - - var promptOptions = VertexAiGeminiChatOptions.builder() - .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) - .functionCallbacks(List.of( - FunctionCallback.builder() - .function("get_current_weather", new MockWeatherService()) - .schemaType(SchemaType.OPEN_API_SCHEMA) - .description("Get the current weather in a given location.") - .inputType(MockWeatherService.Request.class) - .build(), - FunctionCallback.builder() - .function("get_payment_status", new PaymentStatus()) - .schemaType(SchemaType.OPEN_API_SCHEMA) - .description( - "Retrieves the payment status for transaction. For example what is the payment status for transaction 700?") - .inputType(PaymentInfoRequest.class) - .build())) - .build(); - - ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); - - logger.info("Response: " + response); - - assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15"); - - ChatResponse response2 = this.chatModel - .call(new Prompt("What is the payment status for transaction 696?", promptOptions)); - - logger.info("Response: " + response2); - - assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("transaction 696 is PAYED"); - - } - - @Test - public void functionCallTestInferredOpenApiSchema2() { - UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius."); @@ -148,15 +107,15 @@ public class VertexAiGeminiChatModelFunctionCallingIT { var promptOptions = VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) .functionCallbacks(List.of( - FunctionCallback.builder() - .function("get_current_weather", new MockWeatherService()) - .schemaType(SchemaType.OPEN_API_SCHEMA) + FunctionToolCallback.builder("get_current_weather", new MockWeatherService()) + .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) .description("Get the current weather in a given location.") .inputType(MockWeatherService.Request.class) .build(), - FunctionCallback.builder() - .function("get_payment_status", new PaymentStatus()) - .schemaType(SchemaType.OPEN_API_SCHEMA) + FunctionToolCallback.builder("get_payment_status", new PaymentStatus()) + .inputSchema(JsonSchemaGenerator.generateForType(PaymentInfoRequest.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) .description( "Retrieves the payment status for transaction. For example what is the payment status for transaction 700?") .inputType(PaymentInfoRequest.class) @@ -188,9 +147,9 @@ public class VertexAiGeminiChatModelFunctionCallingIT { var promptOptions = VertexAiGeminiChatOptions.builder() .model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("getCurrentWeather", new MockWeatherService()) - .schemaType(SchemaType.OPEN_API_SCHEMA) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build())) diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java index f6c6606a1..ceb19a266 100644 --- a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiPaymentTransactionIT.java @@ -69,7 +69,7 @@ public class VertexAiGeminiPaymentTransactionIT { // @formatter:off String content = this.chatClient.prompt() .advisors(new LoggingAdvisor()) - .functions("paymentStatus") + .tools("paymentStatus") .user(""" What is the status of my payment transactions 001, 002 and 003? If requred invoke the function per transaction. @@ -86,8 +86,7 @@ public class VertexAiGeminiPaymentTransactionIT { Flux streamContent = this.chatClient.prompt() .advisors(new LoggingAdvisor()) - .functions("paymentStatus") - // .functions("paymentStatuses") + .tools("paymentStatus") .user(""" What is the status of my payment transactions 001, 002 and 003? If requred invoke the function per transaction. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index e7d3d821f..c90c5ed46 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -218,10 +218,12 @@ public interface ChatClient { ChatClientRequestSpec tools(Object... toolObjects); - ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks); + // ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks); + @Deprecated ChatClientRequestSpec functions(FunctionCallback... functionCallbacks); + @Deprecated ChatClientRequestSpec functions(String... functionBeanNames); ChatClientRequestSpec toolContext(Map toolContext); @@ -281,10 +283,17 @@ public interface ChatClient { Builder defaultTools(Object... toolObjects); - Builder defaultToolCallbacks(FunctionCallback... toolCallbacks); - + /** + * @deprecated in favor of {@link #defaultTools(String...)} + */ + @Deprecated Builder defaultFunctions(String... functionNames); + /** + * @deprecated in favor of {@link #defaultTools(FunctionCallback...)} or + * {@link #defaultToolCallbacks(FunctionCallback...)} + */ + @Deprecated Builder defaultFunctions(FunctionCallback... functionCallbacks); Builder defaultToolContext(Map toolContext); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 0d3ea5f5c..c06452c5a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -849,22 +849,38 @@ public class DefaultChatClient implements ChatClient { public ChatClientRequestSpec tools(Object... toolObjects) { Assert.notNull(toolObjects, "toolObjects cannot be null"); Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); - this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects))); + + List functionCallbacks = new ArrayList<>(); + List nonFunctinCallbacks = new ArrayList<>(); + for (Object toolObject : toolObjects) { + if (toolObject instanceof FunctionCallback) { + functionCallbacks.add((FunctionCallback) toolObject); + } + else { + nonFunctinCallbacks.add(toolObject); + } + } + this.functionCallbacks.addAll(functionCallbacks); + this.functionCallbacks.addAll(Arrays + .asList(ToolCallbacks.from(nonFunctinCallbacks.toArray(new Object[nonFunctinCallbacks.size()])))); return this; } - @Override - public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) { - Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); - this.functionCallbacks.addAll(Arrays.asList(toolCallbacks)); - return this; - } + // @Override + // public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) { + // Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + // Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null + // elements"); + // this.functionCallbacks.addAll(Arrays.asList(toolCallbacks)); + // return this; + // } + @Deprecated public ChatClientRequestSpec functions(String... functionBeanNames) { return tools(functionBeanNames); } + @Deprecated public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) { Assert.notNull(functionCallbacks, "functionCallbacks cannot be null"); Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 60312e9cc..d5a565990 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -157,13 +157,7 @@ public class DefaultChatClientBuilder implements Builder { @Override public Builder defaultTools(Object... toolObjects) { - this.defaultRequest.functions(ToolCallbacks.from(toolObjects)); - return this; - } - - @Override - public Builder defaultToolCallbacks(FunctionCallback... toolCallbacks) { - this.defaultRequest.functions(toolCallbacks); + this.defaultRequest.tools(toolObjects); return this; } @@ -201,7 +195,7 @@ public class DefaultChatClientBuilder implements Builder { void addToolCallbacks(List toolCallbacks) { Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - this.defaultRequest.toolCallbacks(toolCallbacks.toArray(FunctionCallback[]::new)); + this.defaultRequest.tools(toolCallbacks.toArray(FunctionCallback[]::new)); } void addToolContext(Map toolContext) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java index d25367c5b..bee3ec03c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java @@ -52,14 +52,20 @@ public interface ToolDefinition { } /** - * Create a default {@link ToolDefinition} instance from a {@link Method}. + * Create a default {@link ToolDefinition} builder from a {@link Method}. */ - static ToolDefinition from(Method method) { + static DefaultToolDefinition.Builder builder(Method method) { return DefaultToolDefinition.builder() .name(ToolUtils.getToolName(method)) .description(ToolUtils.getToolDescription(method)) - .inputSchema(JsonSchemaGenerator.generateForMethodInput(method)) - .build(); + .inputSchema(JsonSchemaGenerator.generateForMethodInput(method)); + } + + /** + * Create a default {@link ToolDefinition} instance from a {@link Method}. + */ + static ToolDefinition from(Method method) { + return ToolDefinition.builder(method).build(); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index c59fef184..c4b153e02 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -38,6 +38,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; +import org.springframework.util.ReflectionUtils; /** * A {@link ToolCallback} implementation to invoke methods as tools. @@ -64,10 +65,11 @@ public class MethodToolCallback implements ToolCallback { private final ToolCallResultConverter toolCallResultConverter; public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod, - Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) { + @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) { Assert.notNull(toolDefinition, "toolDefinition cannot be null"); Assert.notNull(toolMethod, "toolMethod cannot be null"); - Assert.notNull(toolObject, "toolObject cannot be null"); + Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null, + "toolObject cannot be null for non-static methods"); this.toolDefinition = toolDefinition; this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA; this.toolMethod = toolMethod; @@ -165,7 +167,7 @@ public class MethodToolCallback implements ToolCallback { } private boolean isObjectNotPublic() { - return !Modifier.isPublic(toolObject.getClass().getModifiers()); + return toolObject != null && !Modifier.isPublic(toolObject.getClass().getModifiers()); } private boolean isMethodNotPublic() { diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java index ecb58e86b..c29d87495 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java @@ -1373,9 +1373,9 @@ class DefaultChatClientTests { void whenToolCallbacksElementIsNullThenThrow() { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); - assertThatThrownBy(() -> spec.toolCallbacks(mock(ToolCallback.class), null)) + assertThatThrownBy(() -> spec.tools(mock(ToolCallback.class), null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("toolCallbacks cannot contain null elements"); + .hasMessage("toolObjects cannot contain null elements"); } @Test @@ -1383,7 +1383,7 @@ class DefaultChatClientTests { ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build(); ChatClient.ChatClientRequestSpec spec = chatClient.prompt(); ToolCallback toolCallback = mock(ToolCallback.class); - spec = spec.toolCallbacks(toolCallback); + spec = spec.tools(toolCallback); DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec; assertThat(defaultSpec.getFunctionCallbacks()).contains(toolCallback); } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java index edbe54118..7c0cc7060 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java @@ -55,15 +55,17 @@ class MethodToolCallbackProviderTests { assertThat(callbacks).hasSize(2); - var callback1 = Stream.of(callbacks).filter(c -> c.getName().equals("testMethod")).findFirst(); + var callback1 = Stream.of(callbacks).filter(c -> c.getToolDefinition().name().equals("testMethod")).findFirst(); assertThat(callback1).isPresent(); - assertThat(callback1.get().getName()).isEqualTo("testMethod"); - assertThat(callback1.get().getDescription()).isEqualTo("Test description"); + assertThat(callback1.get().getToolDefinition().name()).isEqualTo("testMethod"); + assertThat(callback1.get().getToolDefinition().description()).isEqualTo("Test description"); - var callback2 = Stream.of(callbacks).filter(c -> c.getName().equals("testStaticMethod")).findFirst(); + var callback2 = Stream.of(callbacks) + .filter(c -> c.getToolDefinition().name().equals("testStaticMethod")) + .findFirst(); assertThat(callback2).isPresent(); - assertThat(callback2.get().getName()).isEqualTo("testStaticMethod"); - assertThat(callback2.get().getDescription()).isEqualTo("Test description"); + assertThat(callback2.get().getToolDefinition().name()).isEqualTo("testStaticMethod"); + assertThat(callback2.get().getToolDefinition().description()).isEqualTo("Test description"); } @Test diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java index 244baa0be..4f1c6d0f1 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java @@ -76,7 +76,7 @@ public class FunctionToolCallbackTests { .build() .prompt() .user("Welcome the users to the library") - .toolCallbacks(FunctionToolCallback.builder("sayWelcome", (input) -> { + .tools(FunctionToolCallback.builder("sayWelcome", (input) -> { logger.info("CALLBACK - Welcoming users to the library"); }) .description("Welcome users to the library") @@ -105,8 +105,8 @@ public class FunctionToolCallbackTests { .build() .prompt() .user("Welcome %s to the library".formatted("James Bond")) - .toolCallbacks(FunctionToolCallback.builder("welcomeUser", (user) -> { - logger.info("CALLBACK - Welcoming "+ ((User) user).name() +" to the library"); + .tools(FunctionToolCallback.builder("welcomeUser", (user) -> { + logger.info("CALLBACK - Welcoming " + ((User) user).name() + " to the library"); }) .description("Welcome a specific user to the library") .inputType(User.class) @@ -141,7 +141,7 @@ public class FunctionToolCallbackTests { .build() .prompt() .user("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")) - .toolCallbacks(FunctionToolCallback.builder("availableBooksByAuthor", function) + .tools(FunctionToolCallback.builder("availableBooksByAuthor", function) .description("Get the list of books written by the given author available in the library") .inputType(Author.class) .build()) @@ -175,7 +175,7 @@ public class FunctionToolCallbackTests { .build() .prompt() .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) - .toolCallbacks(FunctionToolCallback.builder("authorsByAvailableBooks", function) + .tools(FunctionToolCallback.builder("authorsByAvailableBooks", function) .description("Get the list of authors who wrote the given books available in the library") .inputType(Books.class) .build()) diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java index 3cc7a8fcb..c8175d902 100644 --- a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/MethodToolCallbackTests.java @@ -86,9 +86,9 @@ public class MethodToolCallbackTests { .call() .content(); assertThat(content).isNotEmpty() - .contains("The Hobbit") - .contains("The Lord of The Rings") - .contains("The Silmarillion"); + .containsIgnoringCase("The Hobbit") + .containsIgnoringCase("The Lord of The Rings") + .containsIgnoringCase("The Silmarillion"); } @Test @@ -109,7 +109,7 @@ public class MethodToolCallbackTests { .build() .prompt() .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) - .toolCallbacks(ToolCallbacks.from(tools)) + .tools(ToolCallbacks.from(tools)) .call() .content(); assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java index 332f1e55a..8ad18487b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/anthropic/tool/FunctionCallWithPromptFunctionIT.java @@ -28,7 +28,7 @@ import org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -57,11 +57,11 @@ public class FunctionCallWithPromptFunctionIT { "What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius."); var promptOptions = AnthropicChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location. Return temperature in 36°F or 36°C format.") - .inputType(MockWeatherService.Request.class) - .build())) + .functionCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java index 1a09f9b63..f3d6747ab 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionBeanIT.java @@ -29,7 +29,7 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -95,7 +95,7 @@ class FunctionCallWithFunctionBeanIT { "What's the weather like in San Francisco, Paris and in Tokyo? Use Multi-turn function calling."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().function("weatherFunction").build())); + ToolCallingChatOptions.builder().tools("weatherFunction").build())); logger.info("Response: " + response); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java index 64c283185..05d5e53e7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithFunctionWrapperIT.java @@ -27,7 +27,8 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -76,10 +77,9 @@ public class FunctionCallWithFunctionWrapperIT { static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public ToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("WeatherInfo", new MockWeatherService()) + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) .description("Get the current weather in a given location") .inputType(MockWeatherService.Request.class) .build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java index 8d7c83998..7c17288a4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/tool/FunctionCallWithPromptFunctionIT.java @@ -27,7 +27,7 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatOptions; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -60,11 +60,11 @@ public class FunctionCallWithPromptFunctionIT { "What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling."); var promptOptions = AzureOpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) + .functionCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java index f7eda1d82..364ef52f2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithFunctionBeanIT.java @@ -30,7 +30,7 @@ import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -63,14 +63,14 @@ class FunctionCallWithFunctionBeanIT { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().function("weatherFunction").build())); + ToolCallingChatOptions.builder().tools("weatherFunction").build())); logger.info("Response: " + response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().function("weatherFunction3").build())); + ToolCallingChatOptions.builder().tools("weatherFunction3").build())); logger.info("Response: " + response); @@ -92,7 +92,7 @@ class FunctionCallWithFunctionBeanIT { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); Flux responses = chatModel.stream(new Prompt(List.of(userMessage), - FunctionCallingOptions.builder().function("weatherFunction").build())); + ToolCallingChatOptions.builder().tools("weatherFunction").build())); String content = responses.collectList() .block() diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java index b9e6b59dc..b8f6433ed 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/tool/FunctionCallWithPromptFunctionIT.java @@ -27,8 +27,8 @@ import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -55,12 +55,12 @@ public class FunctionCallWithPromptFunctionIT { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius."); - var promptOptions = FunctionCallingOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location. Return temperature in 36°F or 36°C format.") - .inputType(MockWeatherService.Request.class) - .build())) + var promptOptions = ToolCallingChatOptions.builder() + .toolCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location. Return temperature in 36°F or 36°C format.") + .inputType(MockWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index 93b6b46b1..8967957a7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -30,7 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.MistralAiChatModel; import org.springframework.ai.mistralai.MistralAiChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -63,8 +63,8 @@ public class PaymentStatusPromptIT { UserMessage userMessage = new UserMessage("What's the status of my transaction with id T1001?"); var promptOptions = MistralAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("retrievePaymentStatus", + .functionCallbacks(List.of(FunctionToolCallback + .builder("retrievePaymentStatus", (Transaction transaction) -> new Status(DATA.get(transaction).status())) .description("Get payment status of a transaction") .inputType(Transaction.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index 0fa7e93e4..d2a8028f0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -35,8 +35,8 @@ import org.springframework.ai.mistralai.MistralAiChatModel; import org.springframework.ai.mistralai.MistralAiChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -72,11 +72,11 @@ public class WeatherServicePromptIT { var promptOptions = MistralAiChatOptions.builder() .toolChoice(ToolChoice.AUTO) - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MyWeatherService()) - .description("Get the current weather in requested location") - .inputType(MyWeatherService.Request.class) - .build())) + .functionCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService()) + .description("Get the current weather in requested location") + .inputType(MyWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); @@ -97,9 +97,8 @@ public class WeatherServicePromptIT { UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius."); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MyWeatherService()) + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder() + .toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService()) .description("Get the current weather in requested location") .inputType(MyWeatherService.Request.class) .build())) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java index 23ad3daaa..6a51c0e66 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/FunctionCallbackInPromptIT.java @@ -30,9 +30,9 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -69,8 +69,8 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); var promptOptions = OllamaOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback + .builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) @@ -95,8 +95,8 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT { "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); var promptOptions = OllamaOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) + .functionCallbacks(List.of(FunctionToolCallback + .builder("CurrentWeatherService", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java index 736c9308c..f39e75dd0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/ollama/tool/OllamaFunctionCallbackIT.java @@ -30,10 +30,11 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -117,7 +118,7 @@ public class OllamaFunctionCallbackIT extends BaseOllamaIT { UserMessage userMessage = new UserMessage( "What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations."); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder().function("WeatherInfo").build(); + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("WeatherInfo").build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); @@ -131,10 +132,9 @@ public class OllamaFunctionCallbackIT extends BaseOllamaIT { static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public ToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("WeatherInfo", new MockWeatherService()) + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) .description( "Find the weather conditions, forecasts, and temperatures for a location, like a city or state.") .inputType(MockWeatherService.Request.class) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java index 19f87a3d9..d8714b0d6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPrompt2IT.java @@ -25,9 +25,9 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -58,12 +58,12 @@ public class FunctionCallbackInPrompt2IT { .call().content(); String content = ChatClient.builder(chatModel).build().prompt() - .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build()) + .user("What's the weather like in San Francisco, Tokyo, and Paris?") + .functions(FunctionToolCallback + .builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build()) .call().content(); // @formatter:on @@ -87,8 +87,8 @@ public class FunctionCallbackInPrompt2IT { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("Turn the light on in the kitchen and in the living room!") - .functions(FunctionCallback.builder() - .function("turnLight", (LightInfo lightInfo) -> { + .functions(FunctionToolCallback + .builder("turnLight", (LightInfo lightInfo) -> { logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName()); state.put(lightInfo.roomName(), lightInfo.isOn()); }) @@ -113,8 +113,8 @@ public class FunctionCallbackInPrompt2IT { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in Amsterdam?") - .functions(FunctionCallback.builder() - .function("CurrentWeatherService", input -> "18 degrees Celsius") + .functions(FunctionToolCallback + .builder("CurrentWeatherService", input -> "18 degrees Celsius") .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) @@ -137,8 +137,8 @@ public class FunctionCallbackInPrompt2IT { // @formatter:off String content = ChatClient.builder(chatModel).build().prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris?") - .functions(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) + .functions(FunctionToolCallback + .builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build()) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index 539e137e1..b4195bf55 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -29,10 +29,10 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.core.log.LogAccessor; @@ -61,11 +61,11 @@ public class FunctionCallbackInPromptIT { "What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) + .functionCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) .build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions)); @@ -90,11 +90,11 @@ public class FunctionCallbackInPromptIT { "What's the weather like in San Francisco, Tokyo, and Paris?"); var promptOptions = OpenAiChatOptions.builder() - .functionCallbacks(List.of(FunctionCallback.builder() - .function("CurrentWeatherService", new MockWeatherService()) - .description("Get the weather in location") - .inputType(MockWeatherService.Request.class) - .build())) + .functionCallbacks( + List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) + .description("Get the weather in location") + .inputType(MockWeatherService.Request.class) + .build())) .build(); Flux response = chatModel.stream(new Prompt(List.of(userMessage), promptOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 039790be5..513d2dbfd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -38,7 +38,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; @@ -154,9 +154,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { UserMessage userMessage = new UserMessage( "Please schedule a train from San Francisco to Los Angeles on 2023-12-25"); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .function("trainReservation") - .build(); + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("trainReservation").build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); @@ -265,9 +263,7 @@ class FunctionCallbackWithPlainFunctionBeanIT { // Test weatherFunction UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - FunctionCallingOptions functionOptions = FunctionCallingOptions.builder() - .function("weatherFunction") - .build(); + ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("weatherFunction").build(); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions)); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallback2IT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallback2IT.java index 49dfeca87..c361a8971 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallback2IT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallback2IT.java @@ -23,9 +23,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -92,10 +93,9 @@ public class OpenAiFunctionCallback2IT { static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public ToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("WeatherInfo", new MockWeatherService()) + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallbackIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallbackIT.java index dbb6eea23..ffc799fe7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallbackIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/OpenAiFunctionCallbackIT.java @@ -29,10 +29,11 @@ import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -103,10 +104,9 @@ public class OpenAiFunctionCallbackIT { static class Config { @Bean - public FunctionCallback weatherFunctionInfo() { + public ToolCallback weatherFunctionInfo() { - return FunctionCallback.builder() - .function("WeatherInfo", new MockWeatherService()) + return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService()) .description("Get the weather in location") .inputType(MockWeatherService.Request.class) .build();