refactor: migrate from functions to tools terminology
Refactors the test codebase to use tools instead of functions. - Rename FunctionCallback to FunctionToolCallback - Rename FunctionCallingOptions to ToolCallingChatOptions - Update API methods from functions() to tools() - Deprecate function-related methods in favor of tool alternatives - Refactor MethodToolCallback implementation with improved builder pattern - Update all tests to use new tool-based APIs - Add funcs to tools migration guide Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
This commit is contained in:
234
FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md
Normal file
234
FUNCTIONS-TO-TOOLS-API-MIGRATION-GUIDE.md
Normal file
@@ -0,0 +1,234 @@
|
||||
# Migrating from FunctionCallback to ToolCallback API
|
||||
|
||||
This guide helps you migrate from the deprecated FunctionCallback API to the new ToolCallback API in Spring AI.
|
||||
|
||||
## Overview of Changes
|
||||
|
||||
The Spring AI project is moving from "functions" to "tools" terminology to better align with industry standards. This involves several API changes while maintaining backward compatibility through deprecated methods.
|
||||
|
||||
## Key Changes
|
||||
|
||||
1. `FunctionCallback` → `ToolCallback`
|
||||
2. `FunctionCallback.builder().functions()` → `FunctionToolCallback.builder()`
|
||||
3. `FunctionCallback.builder().method()` → `MethodToolCallback.builder()`
|
||||
4. `FunctionCallingOptions` → `ToolCallingChatOptions`
|
||||
5. Method names from `functions()` → `tools()`
|
||||
|
||||
## Migration Examples
|
||||
|
||||
### 1. Basic Function Callback
|
||||
|
||||
Before:
|
||||
```java
|
||||
FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()
|
||||
```
|
||||
|
||||
### 2. ChatClient Usage
|
||||
|
||||
Before:
|
||||
```java
|
||||
String response = ChatClient.create(chatModel)
|
||||
.prompt()
|
||||
.user("What's the weather like in San Francisco?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
String response = ChatClient.create(chatModel)
|
||||
.prompt()
|
||||
.user("What's the weather like in San Francisco?")
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
```
|
||||
|
||||
### 3. Method-Based Function Callbacks
|
||||
|
||||
Before:
|
||||
```java
|
||||
FunctionCallback.builder()
|
||||
.method("getWeatherInLocation", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetClass(TestFunctionClass.class)
|
||||
.build()
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherInLocation", String.class, Unit.class);
|
||||
|
||||
MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.build()
|
||||
```
|
||||
|
||||
And you can use the same `ChatClient#tools()` API to register method-based tool callbackes:
|
||||
|
||||
```java
|
||||
String response = ChatClient.create(chatModel)
|
||||
.prompt()
|
||||
.user("What's the weather like in San Francisco?")
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
```
|
||||
|
||||
### 4. Options Configuration
|
||||
|
||||
Before:
|
||||
```java
|
||||
FunctionCallingOptions.builder()
|
||||
.model(modelName)
|
||||
.function("weatherFunction")
|
||||
.build()
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
ToolCallingChatOptions.builder()
|
||||
.model(modelName)
|
||||
.tools("weatherFunction")
|
||||
.build()
|
||||
```
|
||||
|
||||
### 5. Default Functions in ChatClient Builder
|
||||
|
||||
Before:
|
||||
```java
|
||||
ChatClient.builder(chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.build()
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
ChatClient.builder(chatModel)
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.build()
|
||||
```
|
||||
|
||||
### 6. Spring Bean Configuration
|
||||
|
||||
Before:
|
||||
```java
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
return FunctionCallback.builder()
|
||||
.function("WeatherInfo", new MockWeatherService())
|
||||
.description("Get the current weather")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
}
|
||||
```
|
||||
|
||||
After:
|
||||
```java
|
||||
@Bean
|
||||
public ToolCallback weatherFunctionInfo() {
|
||||
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
|
||||
.description("Get the current weather")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
}
|
||||
```
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
1. The `method()` configuration in function callbacks has been replaced with a more explicit method tool configuration using `ToolDefinition` and `MethodToolCallback`.
|
||||
|
||||
2. When using method-based callbacks, you now need to explicitly find the method using `ReflectionUtils` and provide it to the builder.
|
||||
|
||||
3. For non-static methods, you must now provide both the method and the target object:
|
||||
```java
|
||||
MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Description")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build()
|
||||
```
|
||||
|
||||
## Deprecated Methods
|
||||
|
||||
The following methods are deprecated and will be removed in a future release:
|
||||
|
||||
- `ChatClient.Builder.defaultFunctions(String...)`
|
||||
- `ChatClient.Builder.defaultFunctions(FunctionCallback...)`
|
||||
- `ChatClient.RequestSpec.functions()`
|
||||
|
||||
Use their `tools` counterparts instead.
|
||||
|
||||
## @Tool tool definition path.
|
||||
|
||||
Now you can use the method-level annothation (`@Tool`) to register tools with Spring AI
|
||||
|
||||
```java
|
||||
public class Home {
|
||||
|
||||
@Tool(description = "Turn light On or Off in a room.")
|
||||
public void turnLight(String roomName, boolean on) {
|
||||
// ...
|
||||
logger.info("Turn light in room: {} to: {}", roomName, on);
|
||||
}
|
||||
}
|
||||
|
||||
Home homeAutomation = new HomeAutomation();
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn the light in the living room On.")
|
||||
.tools(homeAutomation)
|
||||
.call()
|
||||
.content();
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Additional Notes
|
||||
|
||||
1. The new API provides better separation between tool definition and implementation.
|
||||
2. Tool definitions can be reused across different implementations.
|
||||
3. The builder pattern has been simplified for common use cases.
|
||||
4. Better support for method-based tools with improved error handling.
|
||||
|
||||
## Timeline
|
||||
|
||||
The deprecated methods will be maintained for backward compatibility in the current major version but will be removed in the next major release. It's recommended to migrate to the new API as soon as possible.
|
||||
@@ -47,8 +47,8 @@ import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
@@ -258,7 +258,7 @@ class AnthropicChatModelIT {
|
||||
List.of(new Media(new MimeType("application", "pdf"), pdfData)));
|
||||
|
||||
var response = this.chatModel.call(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build()));
|
||||
ToolCallingChatOptions.builder().model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()).build()));
|
||||
|
||||
assertThat(response.getResult().getOutput().getText()).containsAnyOf("Spring AI", "portable API");
|
||||
}
|
||||
@@ -273,8 +273,7 @@ class AnthropicChatModelIT {
|
||||
|
||||
var promptOptions = AnthropicChatOptions.builder()
|
||||
.model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -306,8 +305,7 @@ class AnthropicChatModelIT {
|
||||
|
||||
var promptOptions = AnthropicChatOptions.builder()
|
||||
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -337,8 +335,7 @@ class AnthropicChatModelIT {
|
||||
|
||||
var promptOptions = AnthropicChatOptions.builder()
|
||||
.model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
|
||||
@@ -40,7 +40,7 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -211,8 +211,7 @@ class AnthropicChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call()
|
||||
@@ -230,8 +229,7 @@ class AnthropicChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeatherInLocation", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeatherInLocation", new MockWeatherService())
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call()
|
||||
@@ -248,8 +246,7 @@ class AnthropicChatClientIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -271,8 +268,7 @@ class AnthropicChatClientIT {
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
|
||||
@@ -29,11 +29,14 @@ import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ToolContext;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.tool.method.MethodToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
@@ -41,6 +44,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
@SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429")
|
||||
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
|
||||
@ActiveProfiles("logging-test")
|
||||
@SuppressWarnings("null")
|
||||
class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
private static final LogAccessor logger = new LogAccessor(
|
||||
@@ -57,11 +61,14 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
void methodGetWeatherGeneratedDescription() {
|
||||
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherInLocation", String.class, Unit.class);
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherInLocation", String.class, Unit.class)
|
||||
.targetClass(TestFunctionClass.class)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod).build())
|
||||
.toolMethod(toolMethod)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -76,12 +83,16 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
void methodGetWeatherStatic() {
|
||||
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherStatic", String.class, Unit.class);
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetClass(TestFunctionClass.class)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -98,12 +109,18 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
|
||||
var turnLightMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "turnLight", String.class, boolean.class);
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn light on in the living room.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("turnLight", String.class, boolean.class)
|
||||
.description("Turn light on in the living room.")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(turnLightMethod)
|
||||
.description("Turn light on in the living room.")
|
||||
.build())
|
||||
.toolMethod(turnLightMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -121,12 +138,17 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherNonStatic", String.class, Unit.class);
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherNonStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -143,17 +165,21 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherWithContext", String.class, Unit.class, ToolContext.class);
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherWithContext", String.class, Unit.class, ToolContext.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.toolContext(Map.of("tool", "value"))
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
logger.info("Response: " + response);
|
||||
|
||||
@@ -170,18 +196,23 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "getWeatherNonStatic", String.class, Unit.class);
|
||||
|
||||
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherNonStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.toolContext(Map.of("tool", "value"))
|
||||
.call()
|
||||
.content())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("Configured method does not accept ToolContext as input parameter!");
|
||||
.hasMessage("ToolContext is not supported by the method as an argument");
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@@ -190,13 +221,18 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
// @formatter:off
|
||||
var toolMethod = ReflectionUtils.findMethod(
|
||||
TestFunctionClass.class, "turnLivingRoomLightOn");
|
||||
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn light on in the living room.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("turnLivingRoomLightOn")
|
||||
.description("Can turn lights on in the Living Room")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolMethod(toolMethod)
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Can turn lights on in the Living Room")
|
||||
.build())
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -207,6 +243,25 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
assertThat(arguments).containsEntry("turnLivingRoomLightOn", true);
|
||||
}
|
||||
|
||||
@Test
|
||||
void toolAnnotation() {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn light red in the living room.")
|
||||
.tools(targetObject)
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(arguments).containsEntry("roomName", "living room")
|
||||
.containsEntry("color", TestFunctionClass.LightColor.RED);
|
||||
}
|
||||
|
||||
@Autowired
|
||||
ChatModel chatModel;
|
||||
|
||||
@@ -270,6 +325,19 @@ class AnthropicChatClientMethodInvokingFunctionCallbackIT {
|
||||
arguments.put("turnLivingRoomLightOn", true);
|
||||
}
|
||||
|
||||
enum LightColor {
|
||||
|
||||
RED, GREEN, BLUE
|
||||
|
||||
}
|
||||
|
||||
@Tool(description = "Change the lamp color in a room.")
|
||||
public void changeRoomLightColor(String roomName, LightColor color) {
|
||||
arguments.put("roomName", roomName);
|
||||
arguments.put("color", color);
|
||||
logger.info("Change light colur in room: {} to color: {}", roomName, color);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -69,8 +69,7 @@ class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(this.selectedModel)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -98,8 +97,7 @@ class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(this.selectedModel)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -120,8 +118,7 @@ class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(this.selectedModel)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -158,8 +155,7 @@ class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(this.selectedModel)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -185,8 +181,7 @@ class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(this.selectedModel)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
|
||||
@@ -34,8 +34,8 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -210,8 +210,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel)
|
||||
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -230,8 +229,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
ChatResponse response = ChatClient.create(this.chatModel)
|
||||
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -265,8 +263,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel)
|
||||
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -285,8 +282,7 @@ class BedrockConverseChatClientIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -308,8 +304,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
Flux<ChatResponse> response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -350,8 +345,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in Paris? Return the temperature in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -371,7 +365,7 @@ class BedrockConverseChatClientIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(FunctionCallingOptions.builder().model(modelName).build())
|
||||
.options(ToolCallingChatOptions.builder().model(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?")
|
||||
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
|
||||
.call()
|
||||
@@ -393,7 +387,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
// TODO consider adding model(...) method to ChatClient as a shortcut to
|
||||
.options(FunctionCallingOptions.builder().model(modelName).build())
|
||||
.options(ToolCallingChatOptions.builder().model(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
|
||||
.call()
|
||||
.content();
|
||||
|
||||
@@ -29,8 +29,8 @@ import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -145,8 +145,7 @@ public class BedrockNovaChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", (WeatherRequest request) -> {
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", (WeatherRequest request) -> {
|
||||
if (request.location().contains("Paris")) {
|
||||
return new WeatherResponse(15, request.unit());
|
||||
}
|
||||
@@ -183,7 +182,7 @@ public class BedrockNovaChatClientIT {
|
||||
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.withRegion(Region.US_EAST_1)
|
||||
.withTimeout(Duration.ofSeconds(120))
|
||||
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
|
||||
.withDefaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -224,8 +224,7 @@ class MistralAiChatClientIT {
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).toolChoice(ToolChoice.AUTO).build())
|
||||
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius."))
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -246,8 +245,7 @@ class MistralAiChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -270,8 +268,7 @@ class MistralAiChatClientIT {
|
||||
Flux<String> response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL).build())
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use parallel function calling if required. Response should be in Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
|
||||
@@ -47,6 +47,7 @@ import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -97,7 +98,8 @@ class MistralAiChatModelIT {
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.");
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
// NOTE: Mistral expects the system message to be before the user message or will
|
||||
// NOTE: Mistral expects the system message to be before the user message or
|
||||
// will
|
||||
// fail with 400 error.
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
@@ -202,8 +204,7 @@ class MistralAiChatModelIT {
|
||||
|
||||
var promptOptions = MistralAiChatOptions.builder()
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -228,8 +229,7 @@ class MistralAiChatModelIT {
|
||||
|
||||
var promptOptions = MistralAiChatOptions.builder()
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
@@ -316,8 +316,7 @@ class MistralAiChatModelIT {
|
||||
|
||||
var promptOptions = MistralAiChatOptions.builder()
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
|
||||
@@ -30,10 +30,10 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -61,8 +61,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
|
||||
|
||||
var promptOptions = OllamaOptions.builder()
|
||||
.model(MODEL)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -85,8 +84,7 @@ class OllamaChatModelFunctionCallingIT extends BaseOllamaIT {
|
||||
|
||||
var promptOptions = OllamaOptions.builder()
|
||||
.model(MODEL)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
|
||||
@@ -35,13 +35,13 @@ import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
@@ -245,16 +245,13 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
@Test
|
||||
void functionCallTest() {
|
||||
|
||||
FunctionCallback functionCallback = FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
|
||||
.functions(functionCallback)
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
@@ -269,8 +266,7 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -290,8 +286,7 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
|
||||
@@ -26,12 +26,14 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ToolContext;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.tool.method.MethodToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
||||
@@ -55,13 +57,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
@Test
|
||||
void methodGetWeatherStatic() {
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherStatic", String.class,
|
||||
Unit.class);
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetClass(TestFunctionClass.class)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -77,13 +84,17 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLight", String.class, boolean.class);
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn light on in the living room.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("turnLight", String.class, boolean.class)
|
||||
.description("Can turn lights on or off by room name")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Can turn lights on or off by room name")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -100,13 +111,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
|
||||
Unit.class);
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherNonStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
@@ -122,13 +138,18 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class,
|
||||
Unit.class, ToolContext.class);
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherWithContext", String.class, Unit.class, ToolContext.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.toolContext(Map.of("tool", "value"))
|
||||
.call()
|
||||
@@ -146,19 +167,24 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class,
|
||||
Unit.class);
|
||||
|
||||
// @formatter:off
|
||||
assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("getWeatherNonStatic", String.class, Unit.class)
|
||||
.description("Get the weather in location")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Get the weather in location")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.toolContext(Map.of("tool", "value"))
|
||||
.call()
|
||||
.content())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("Configured method does not accept ToolContext as input parameter!");
|
||||
.hasMessage("ToolContext is not supported by the method as an argument");
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
@@ -167,13 +193,17 @@ class OpenAiChatClientMethodInvokingFunctionCallbackIT {
|
||||
|
||||
TestFunctionClass targetObject = new TestFunctionClass();
|
||||
|
||||
var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "turnLivingRoomLightOn");
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("Turn light on in the living room.")
|
||||
.functions(FunctionCallback.builder()
|
||||
.method("turnLivingRoomLightOn")
|
||||
.description("Can turn lights on in the Living Room")
|
||||
.targetObject(targetObject)
|
||||
.tools(MethodToolCallback.builder()
|
||||
.toolDefinition(ToolDefinition.builder(toolMethod)
|
||||
.description("Can turn lights on in the Living Room")
|
||||
.build())
|
||||
.toolMethod(toolMethod)
|
||||
.toolObject(targetObject)
|
||||
.build())
|
||||
.call()
|
||||
.content();
|
||||
|
||||
@@ -29,12 +29,12 @@ import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ToolContext;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService.Request;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService.Response;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.io.Resource;
|
||||
@@ -83,8 +83,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
// @formatter:off
|
||||
response = chatClientBuilder.build().prompt()
|
||||
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -114,8 +113,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -157,8 +155,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", biFunction)
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction)
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -201,8 +198,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(this.chatModel)
|
||||
.defaultFunctions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", biFunction)
|
||||
.defaultTools(FunctionToolCallback.builder("getCurrentWeather", biFunction)
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -224,8 +220,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.create(this.chatModel).prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.tools(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -254,8 +249,7 @@ class OpenAiChatClientMultipleFunctionCallsIT extends AbstractIT {
|
||||
|
||||
String content = chatClient.prompt()
|
||||
.user("What's the weather like in Shanghai?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("currentTemp", function)
|
||||
.tools(FunctionToolCallback.builder("currentTemp", function)
|
||||
.description("get current temp")
|
||||
.inputType(MyFunction.Req.class)
|
||||
.build())
|
||||
|
||||
@@ -122,7 +122,7 @@ class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT {
|
||||
|
||||
chatResponse = chatClient.prompt()
|
||||
.messages(messages)
|
||||
.functions(this.functionDefinition)
|
||||
.tools(this.functionDefinition)
|
||||
.options(OpenAiChatOptions.builder().proxyToolCalls(true).build())
|
||||
.call()
|
||||
.chatResponse();
|
||||
|
||||
@@ -33,8 +33,8 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.ai.util.json.JsonSchemaGenerator;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -82,10 +82,9 @@ public class VertexAiGeminiChatModelFunctionCallingIT {
|
||||
""";
|
||||
|
||||
var promptOptions = VertexAiGeminiChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("get_current_weather", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputTypeSchema(openApiSchema)
|
||||
.inputSchema(openApiSchema)
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
@@ -100,46 +99,6 @@ public class VertexAiGeminiChatModelFunctionCallingIT {
|
||||
@Test
|
||||
public void functionCallTestInferredOpenApiSchema() {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius units.");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = VertexAiGeminiChatOptions.builder()
|
||||
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
|
||||
.functionCallbacks(List.of(
|
||||
FunctionCallback.builder()
|
||||
.function("get_current_weather", new MockWeatherService())
|
||||
.schemaType(SchemaType.OPEN_API_SCHEMA)
|
||||
.description("Get the current weather in a given location.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build(),
|
||||
FunctionCallback.builder()
|
||||
.function("get_payment_status", new PaymentStatus())
|
||||
.schemaType(SchemaType.OPEN_API_SCHEMA)
|
||||
.description(
|
||||
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
|
||||
.inputType(PaymentInfoRequest.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: " + response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getText()).containsAnyOf("15.0", "15");
|
||||
|
||||
ChatResponse response2 = this.chatModel
|
||||
.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
|
||||
|
||||
logger.info("Response: " + response2);
|
||||
|
||||
assertThat(response2.getResult().getOutput().getText()).containsIgnoringCase("transaction 696 is PAYED");
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void functionCallTestInferredOpenApiSchema2() {
|
||||
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius.");
|
||||
|
||||
@@ -148,15 +107,15 @@ public class VertexAiGeminiChatModelFunctionCallingIT {
|
||||
var promptOptions = VertexAiGeminiChatOptions.builder()
|
||||
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
|
||||
.functionCallbacks(List.of(
|
||||
FunctionCallback.builder()
|
||||
.function("get_current_weather", new MockWeatherService())
|
||||
.schemaType(SchemaType.OPEN_API_SCHEMA)
|
||||
FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
|
||||
.inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class,
|
||||
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
|
||||
.description("Get the current weather in a given location.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build(),
|
||||
FunctionCallback.builder()
|
||||
.function("get_payment_status", new PaymentStatus())
|
||||
.schemaType(SchemaType.OPEN_API_SCHEMA)
|
||||
FunctionToolCallback.builder("get_payment_status", new PaymentStatus())
|
||||
.inputSchema(JsonSchemaGenerator.generateForType(PaymentInfoRequest.class,
|
||||
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
|
||||
.description(
|
||||
"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
|
||||
.inputType(PaymentInfoRequest.class)
|
||||
@@ -188,9 +147,9 @@ public class VertexAiGeminiChatModelFunctionCallingIT {
|
||||
|
||||
var promptOptions = VertexAiGeminiChatOptions.builder()
|
||||
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.schemaType(SchemaType.OPEN_API_SCHEMA)
|
||||
.functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.inputSchema(JsonSchemaGenerator.generateForType(MockWeatherService.Request.class,
|
||||
JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES))
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
|
||||
@@ -69,7 +69,7 @@ public class VertexAiGeminiPaymentTransactionIT {
|
||||
// @formatter:off
|
||||
String content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.functions("paymentStatus")
|
||||
.tools("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
@@ -86,8 +86,7 @@ public class VertexAiGeminiPaymentTransactionIT {
|
||||
|
||||
Flux<String> streamContent = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.functions("paymentStatus")
|
||||
// .functions("paymentStatuses")
|
||||
.tools("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
|
||||
@@ -218,10 +218,12 @@ public interface ChatClient {
|
||||
|
||||
ChatClientRequestSpec tools(Object... toolObjects);
|
||||
|
||||
ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
|
||||
// ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks);
|
||||
|
||||
@Deprecated
|
||||
<I, O> ChatClientRequestSpec functions(FunctionCallback... functionCallbacks);
|
||||
|
||||
@Deprecated
|
||||
ChatClientRequestSpec functions(String... functionBeanNames);
|
||||
|
||||
ChatClientRequestSpec toolContext(Map<String, Object> toolContext);
|
||||
@@ -281,10 +283,17 @@ public interface ChatClient {
|
||||
|
||||
Builder defaultTools(Object... toolObjects);
|
||||
|
||||
Builder defaultToolCallbacks(FunctionCallback... toolCallbacks);
|
||||
|
||||
/**
|
||||
* @deprecated in favor of {@link #defaultTools(String...)}
|
||||
*/
|
||||
@Deprecated
|
||||
Builder defaultFunctions(String... functionNames);
|
||||
|
||||
/**
|
||||
* @deprecated in favor of {@link #defaultTools(FunctionCallback...)} or
|
||||
* {@link #defaultToolCallbacks(FunctionCallback...)}
|
||||
*/
|
||||
@Deprecated
|
||||
Builder defaultFunctions(FunctionCallback... functionCallbacks);
|
||||
|
||||
Builder defaultToolContext(Map<String, Object> toolContext);
|
||||
|
||||
@@ -849,22 +849,38 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec tools(Object... toolObjects) {
|
||||
Assert.notNull(toolObjects, "toolObjects cannot be null");
|
||||
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
|
||||
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
|
||||
|
||||
List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
List<Object> nonFunctinCallbacks = new ArrayList<>();
|
||||
for (Object toolObject : toolObjects) {
|
||||
if (toolObject instanceof FunctionCallback) {
|
||||
functionCallbacks.add((FunctionCallback) toolObject);
|
||||
}
|
||||
else {
|
||||
nonFunctinCallbacks.add(toolObject);
|
||||
}
|
||||
}
|
||||
this.functionCallbacks.addAll(functionCallbacks);
|
||||
this.functionCallbacks.addAll(Arrays
|
||||
.asList(ToolCallbacks.from(nonFunctinCallbacks.toArray(new Object[nonFunctinCallbacks.size()]))));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
|
||||
this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
|
||||
return this;
|
||||
}
|
||||
// @Override
|
||||
// public ChatClientRequestSpec toolCallbacks(FunctionCallback... toolCallbacks) {
|
||||
// Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
// Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null
|
||||
// elements");
|
||||
// this.functionCallbacks.addAll(Arrays.asList(toolCallbacks));
|
||||
// return this;
|
||||
// }
|
||||
|
||||
@Deprecated
|
||||
public ChatClientRequestSpec functions(String... functionBeanNames) {
|
||||
return tools(functionBeanNames);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
|
||||
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
|
||||
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");
|
||||
|
||||
@@ -157,13 +157,7 @@ public class DefaultChatClientBuilder implements Builder {
|
||||
|
||||
@Override
|
||||
public Builder defaultTools(Object... toolObjects) {
|
||||
this.defaultRequest.functions(ToolCallbacks.from(toolObjects));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Builder defaultToolCallbacks(FunctionCallback... toolCallbacks) {
|
||||
this.defaultRequest.functions(toolCallbacks);
|
||||
this.defaultRequest.tools(toolObjects);
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -201,7 +195,7 @@ public class DefaultChatClientBuilder implements Builder {
|
||||
|
||||
void addToolCallbacks(List<FunctionCallback> toolCallbacks) {
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
this.defaultRequest.toolCallbacks(toolCallbacks.toArray(FunctionCallback[]::new));
|
||||
this.defaultRequest.tools(toolCallbacks.toArray(FunctionCallback[]::new));
|
||||
}
|
||||
|
||||
void addToolContext(Map<String, Object> toolContext) {
|
||||
|
||||
@@ -52,14 +52,20 @@ public interface ToolDefinition {
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a default {@link ToolDefinition} instance from a {@link Method}.
|
||||
* Create a default {@link ToolDefinition} builder from a {@link Method}.
|
||||
*/
|
||||
static ToolDefinition from(Method method) {
|
||||
static DefaultToolDefinition.Builder builder(Method method) {
|
||||
return DefaultToolDefinition.builder()
|
||||
.name(ToolUtils.getToolName(method))
|
||||
.description(ToolUtils.getToolDescription(method))
|
||||
.inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
|
||||
.build();
|
||||
.inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a default {@link ToolDefinition} instance from a {@link Method}.
|
||||
*/
|
||||
static ToolDefinition from(Method method) {
|
||||
return ToolDefinition.builder(method).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ClassUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.ReflectionUtils;
|
||||
|
||||
/**
|
||||
* A {@link ToolCallback} implementation to invoke methods as tools.
|
||||
@@ -64,10 +65,11 @@ public class MethodToolCallback implements ToolCallback {
|
||||
private final ToolCallResultConverter toolCallResultConverter;
|
||||
|
||||
public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
|
||||
Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
|
||||
@Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
|
||||
Assert.notNull(toolDefinition, "toolDefinition cannot be null");
|
||||
Assert.notNull(toolMethod, "toolMethod cannot be null");
|
||||
Assert.notNull(toolObject, "toolObject cannot be null");
|
||||
Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
|
||||
"toolObject cannot be null for non-static methods");
|
||||
this.toolDefinition = toolDefinition;
|
||||
this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
|
||||
this.toolMethod = toolMethod;
|
||||
@@ -165,7 +167,7 @@ public class MethodToolCallback implements ToolCallback {
|
||||
}
|
||||
|
||||
private boolean isObjectNotPublic() {
|
||||
return !Modifier.isPublic(toolObject.getClass().getModifiers());
|
||||
return toolObject != null && !Modifier.isPublic(toolObject.getClass().getModifiers());
|
||||
}
|
||||
|
||||
private boolean isMethodNotPublic() {
|
||||
|
||||
@@ -1373,9 +1373,9 @@ class DefaultChatClientTests {
|
||||
void whenToolCallbacksElementIsNullThenThrow() {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
|
||||
assertThatThrownBy(() -> spec.toolCallbacks(mock(ToolCallback.class), null))
|
||||
assertThatThrownBy(() -> spec.tools(mock(ToolCallback.class), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("toolCallbacks cannot contain null elements");
|
||||
.hasMessage("toolObjects cannot contain null elements");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -1383,7 +1383,7 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
ChatClient.ChatClientRequestSpec spec = chatClient.prompt();
|
||||
ToolCallback toolCallback = mock(ToolCallback.class);
|
||||
spec = spec.toolCallbacks(toolCallback);
|
||||
spec = spec.tools(toolCallback);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec defaultSpec = (DefaultChatClient.DefaultChatClientRequestSpec) spec;
|
||||
assertThat(defaultSpec.getFunctionCallbacks()).contains(toolCallback);
|
||||
}
|
||||
|
||||
@@ -55,15 +55,17 @@ class MethodToolCallbackProviderTests {
|
||||
|
||||
assertThat(callbacks).hasSize(2);
|
||||
|
||||
var callback1 = Stream.of(callbacks).filter(c -> c.getName().equals("testMethod")).findFirst();
|
||||
var callback1 = Stream.of(callbacks).filter(c -> c.getToolDefinition().name().equals("testMethod")).findFirst();
|
||||
assertThat(callback1).isPresent();
|
||||
assertThat(callback1.get().getName()).isEqualTo("testMethod");
|
||||
assertThat(callback1.get().getDescription()).isEqualTo("Test description");
|
||||
assertThat(callback1.get().getToolDefinition().name()).isEqualTo("testMethod");
|
||||
assertThat(callback1.get().getToolDefinition().description()).isEqualTo("Test description");
|
||||
|
||||
var callback2 = Stream.of(callbacks).filter(c -> c.getName().equals("testStaticMethod")).findFirst();
|
||||
var callback2 = Stream.of(callbacks)
|
||||
.filter(c -> c.getToolDefinition().name().equals("testStaticMethod"))
|
||||
.findFirst();
|
||||
assertThat(callback2).isPresent();
|
||||
assertThat(callback2.get().getName()).isEqualTo("testStaticMethod");
|
||||
assertThat(callback2.get().getDescription()).isEqualTo("Test description");
|
||||
assertThat(callback2.get().getToolDefinition().name()).isEqualTo("testStaticMethod");
|
||||
assertThat(callback2.get().getToolDefinition().description()).isEqualTo("Test description");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -76,7 +76,7 @@ public class FunctionToolCallbackTests {
|
||||
.build()
|
||||
.prompt()
|
||||
.user("Welcome the users to the library")
|
||||
.toolCallbacks(FunctionToolCallback.builder("sayWelcome", (input) -> {
|
||||
.tools(FunctionToolCallback.builder("sayWelcome", (input) -> {
|
||||
logger.info("CALLBACK - Welcoming users to the library");
|
||||
})
|
||||
.description("Welcome users to the library")
|
||||
@@ -105,8 +105,8 @@ public class FunctionToolCallbackTests {
|
||||
.build()
|
||||
.prompt()
|
||||
.user("Welcome %s to the library".formatted("James Bond"))
|
||||
.toolCallbacks(FunctionToolCallback.builder("welcomeUser", (user) -> {
|
||||
logger.info("CALLBACK - Welcoming "+ ((User) user).name() +" to the library");
|
||||
.tools(FunctionToolCallback.builder("welcomeUser", (user) -> {
|
||||
logger.info("CALLBACK - Welcoming " + ((User) user).name() + " to the library");
|
||||
})
|
||||
.description("Welcome a specific user to the library")
|
||||
.inputType(User.class)
|
||||
@@ -141,7 +141,7 @@ public class FunctionToolCallbackTests {
|
||||
.build()
|
||||
.prompt()
|
||||
.user("What books written by %s are available in the library?".formatted("J.R.R. Tolkien"))
|
||||
.toolCallbacks(FunctionToolCallback.builder("availableBooksByAuthor", function)
|
||||
.tools(FunctionToolCallback.builder("availableBooksByAuthor", function)
|
||||
.description("Get the list of books written by the given author available in the library")
|
||||
.inputType(Author.class)
|
||||
.build())
|
||||
@@ -175,7 +175,7 @@ public class FunctionToolCallbackTests {
|
||||
.build()
|
||||
.prompt()
|
||||
.user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia"))
|
||||
.toolCallbacks(FunctionToolCallback.builder("authorsByAvailableBooks", function)
|
||||
.tools(FunctionToolCallback.builder("authorsByAvailableBooks", function)
|
||||
.description("Get the list of authors who wrote the given books available in the library")
|
||||
.inputType(Books.class)
|
||||
.build())
|
||||
|
||||
@@ -86,9 +86,9 @@ public class MethodToolCallbackTests {
|
||||
.call()
|
||||
.content();
|
||||
assertThat(content).isNotEmpty()
|
||||
.contains("The Hobbit")
|
||||
.contains("The Lord of The Rings")
|
||||
.contains("The Silmarillion");
|
||||
.containsIgnoringCase("The Hobbit")
|
||||
.containsIgnoringCase("The Lord of The Rings")
|
||||
.containsIgnoringCase("The Silmarillion");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -109,7 +109,7 @@ public class MethodToolCallbackTests {
|
||||
.build()
|
||||
.prompt()
|
||||
.user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia"))
|
||||
.toolCallbacks(ToolCallbacks.from(tools))
|
||||
.tools(ToolCallbacks.from(tools))
|
||||
.call()
|
||||
.content();
|
||||
assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis");
|
||||
|
||||
@@ -28,7 +28,7 @@ import org.springframework.ai.autoconfigure.anthropic.AnthropicAutoConfiguration
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -57,11 +57,11 @@ public class FunctionCallWithPromptFunctionIT {
|
||||
"What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius.");
|
||||
|
||||
var promptOptions = AnthropicChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.functionCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));
|
||||
|
||||
@@ -29,7 +29,7 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -95,7 +95,7 @@ class FunctionCallWithFunctionBeanIT {
|
||||
"What's the weather like in San Francisco, Paris and in Tokyo? Use Multi-turn function calling.");
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().function("weatherFunction").build()));
|
||||
ToolCallingChatOptions.builder().tools("weatherFunction").build()));
|
||||
|
||||
logger.info("Response: " + response);
|
||||
|
||||
|
||||
@@ -27,7 +27,8 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -76,10 +77,9 @@ public class FunctionCallWithFunctionWrapperIT {
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
public ToolCallback weatherFunctionInfo() {
|
||||
|
||||
return FunctionCallback.builder()
|
||||
.function("WeatherInfo", new MockWeatherService())
|
||||
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
|
||||
.description("Get the current weather in a given location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
|
||||
@@ -27,7 +27,7 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -60,11 +60,11 @@ public class FunctionCallWithPromptFunctionIT {
|
||||
"What's the weather like in San Francisco, in Paris and in Tokyo? Use Multi-turn function calling.");
|
||||
|
||||
var promptOptions = AzureOpenAiChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.functionCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));
|
||||
|
||||
@@ -30,7 +30,7 @@ import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -63,14 +63,14 @@ class FunctionCallWithFunctionBeanIT {
|
||||
"What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius.");
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().function("weatherFunction").build()));
|
||||
ToolCallingChatOptions.builder().tools("weatherFunction").build()));
|
||||
|
||||
logger.info("Response: " + response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
|
||||
|
||||
response = chatModel.call(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().function("weatherFunction3").build()));
|
||||
ToolCallingChatOptions.builder().tools("weatherFunction3").build()));
|
||||
|
||||
logger.info("Response: " + response);
|
||||
|
||||
@@ -92,7 +92,7 @@ class FunctionCallWithFunctionBeanIT {
|
||||
"What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius.");
|
||||
|
||||
Flux<ChatResponse> responses = chatModel.stream(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().function("weatherFunction").build()));
|
||||
ToolCallingChatOptions.builder().tools("weatherFunction").build()));
|
||||
|
||||
String content = responses.collectList()
|
||||
.block()
|
||||
|
||||
@@ -27,8 +27,8 @@ import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -55,12 +55,12 @@ public class FunctionCallWithPromptFunctionIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius.");
|
||||
|
||||
var promptOptions = FunctionCallingOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
var promptOptions = ToolCallingChatOptions.builder()
|
||||
.toolCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location. Return temperature in 36°F or 36°C format.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));
|
||||
|
||||
@@ -30,7 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.mistralai.MistralAiChatModel;
|
||||
import org.springframework.ai.mistralai.MistralAiChatOptions;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -63,8 +63,8 @@ public class PaymentStatusPromptIT {
|
||||
UserMessage userMessage = new UserMessage("What's the status of my transaction with id T1001?");
|
||||
|
||||
var promptOptions = MistralAiChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("retrievePaymentStatus",
|
||||
.functionCallbacks(List.of(FunctionToolCallback
|
||||
.builder("retrievePaymentStatus",
|
||||
(Transaction transaction) -> new Status(DATA.get(transaction).status()))
|
||||
.description("Get payment status of a transaction")
|
||||
.inputType(Transaction.class)
|
||||
|
||||
@@ -35,8 +35,8 @@ import org.springframework.ai.mistralai.MistralAiChatModel;
|
||||
import org.springframework.ai.mistralai.MistralAiChatOptions;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -72,11 +72,11 @@ public class WeatherServicePromptIT {
|
||||
|
||||
var promptOptions = MistralAiChatOptions.builder()
|
||||
.toolChoice(ToolChoice.AUTO)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MyWeatherService())
|
||||
.description("Get the current weather in requested location")
|
||||
.inputType(MyWeatherService.Request.class)
|
||||
.build()))
|
||||
.functionCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService())
|
||||
.description("Get the current weather in requested location")
|
||||
.inputType(MyWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));
|
||||
@@ -97,9 +97,8 @@ public class WeatherServicePromptIT {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in Paris? Use Celsius.");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MyWeatherService())
|
||||
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder()
|
||||
.toolCallbacks(List.of(FunctionToolCallback.builder("CurrentWeatherService", new MyWeatherService())
|
||||
.description("Get the current weather in requested location")
|
||||
.inputType(MyWeatherService.Request.class)
|
||||
.build()))
|
||||
|
||||
@@ -30,9 +30,9 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -69,8 +69,8 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT {
|
||||
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
|
||||
|
||||
var promptOptions = OllamaOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback
|
||||
.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description(
|
||||
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -95,8 +95,8 @@ public class FunctionCallbackInPromptIT extends BaseOllamaIT {
|
||||
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
|
||||
|
||||
var promptOptions = OllamaOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.functionCallbacks(List.of(FunctionToolCallback
|
||||
.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description(
|
||||
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
|
||||
@@ -30,10 +30,11 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -117,7 +118,7 @@ public class OllamaFunctionCallbackIT extends BaseOllamaIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What are the weather conditions in San Francisco, Tokyo, and Paris? Find the temperature in Celsius for each of the three locations.");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder().function("WeatherInfo").build();
|
||||
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("WeatherInfo").build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
|
||||
@@ -131,10 +132,9 @@ public class OllamaFunctionCallbackIT extends BaseOllamaIT {
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
public ToolCallback weatherFunctionInfo() {
|
||||
|
||||
return FunctionCallback.builder()
|
||||
.function("WeatherInfo", new MockWeatherService())
|
||||
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
|
||||
.description(
|
||||
"Find the weather conditions, forecasts, and temperatures for a location, like a city or state.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
|
||||
@@ -25,9 +25,9 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -58,12 +58,12 @@ public class FunctionCallbackInPrompt2IT {
|
||||
.call().content();
|
||||
|
||||
String content = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.functions(FunctionToolCallback
|
||||
.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
.call().content();
|
||||
// @formatter:on
|
||||
|
||||
@@ -87,8 +87,8 @@ public class FunctionCallbackInPrompt2IT {
|
||||
// @formatter:off
|
||||
String content = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("Turn the light on in the kitchen and in the living room!")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("turnLight", (LightInfo lightInfo) -> {
|
||||
.functions(FunctionToolCallback
|
||||
.builder("turnLight", (LightInfo lightInfo) -> {
|
||||
logger.info("Turning light to [" + lightInfo.isOn + "] in " + lightInfo.roomName());
|
||||
state.put(lightInfo.roomName(), lightInfo.isOn());
|
||||
})
|
||||
@@ -113,8 +113,8 @@ public class FunctionCallbackInPrompt2IT {
|
||||
// @formatter:off
|
||||
String content = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("What's the weather like in Amsterdam?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", input -> "18 degrees Celsius")
|
||||
.functions(FunctionToolCallback
|
||||
.builder("CurrentWeatherService", input -> "18 degrees Celsius")
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
@@ -137,8 +137,8 @@ public class FunctionCallbackInPrompt2IT {
|
||||
// @formatter:off
|
||||
String content = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.functions(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.functions(FunctionToolCallback
|
||||
.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build())
|
||||
|
||||
@@ -29,10 +29,10 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.core.log.LogAccessor;
|
||||
@@ -61,11 +61,11 @@ public class FunctionCallbackInPromptIT {
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.functionCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), promptOptions));
|
||||
@@ -90,11 +90,11 @@ public class FunctionCallbackInPromptIT {
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.functionCallbacks(
|
||||
List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = chatModel.stream(new Prompt(List.of(userMessage), promptOptions));
|
||||
|
||||
@@ -38,7 +38,7 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.model.ToolContext;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
@@ -154,9 +154,7 @@ class FunctionCallbackWithPlainFunctionBeanIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"Please schedule a train from San Francisco to Los Angeles on 2023-12-25");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.function("trainReservation")
|
||||
.build();
|
||||
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("trainReservation").build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
|
||||
@@ -265,9 +263,7 @@ class FunctionCallbackWithPlainFunctionBeanIT {
|
||||
// Test weatherFunction
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.function("weatherFunction")
|
||||
.build();
|
||||
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder().tools("weatherFunction").build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
|
||||
|
||||
@@ -23,9 +23,10 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -92,10 +93,9 @@ public class OpenAiFunctionCallback2IT {
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
public ToolCallback weatherFunctionInfo() {
|
||||
|
||||
return FunctionCallback.builder()
|
||||
.function("WeatherInfo", new MockWeatherService())
|
||||
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
|
||||
@@ -29,10 +29,11 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -103,10 +104,9 @@ public class OpenAiFunctionCallbackIT {
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public FunctionCallback weatherFunctionInfo() {
|
||||
public ToolCallback weatherFunctionInfo() {
|
||||
|
||||
return FunctionCallback.builder()
|
||||
.function("WeatherInfo", new MockWeatherService())
|
||||
return FunctionToolCallback.builder("WeatherInfo", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build();
|
||||
|
||||
Reference in New Issue
Block a user