diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java index 7dba4ad25..c25e221eb 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.model.tool; +import java.lang.reflect.Method; import java.util.List; import java.util.Map; @@ -27,6 +28,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.definition.DefaultToolDefinition; @@ -34,6 +36,7 @@ import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor; import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; @@ -45,6 +48,7 @@ import static org.mockito.Mockito.mock; * Unit tests for {@link DefaultToolCallingManager}. * * @author Thomas Vitale + * @author Sun Yuhan */ class DefaultToolCallingManagerTests { @@ -317,6 +321,49 @@ class DefaultToolCallingManagerTests { assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); } + @Test + void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodException { + ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build(); + + ToolDefinition toolDefinitionA = ToolDefinition.builder().name("toolA").inputSchema("{}").build(); + Method methodA = TestGenericClass.class.getMethod("call", String.class); + MethodToolCallback methodToolCallback = MethodToolCallback.builder() + .toolDefinition(toolDefinitionA) + .toolMethod(methodA) + .toolObject(new TestGenericClass()) + .build(); + + ToolDefinition toolDefinitionB = ToolDefinition.builder().name("toolB").inputSchema("{}").build(); + Method methodB = TestGenericClass.class.getMethod("callWithToolContext", ToolContext.class); + MethodToolCallback methodToolCallbackNeedToolContext = MethodToolCallback.builder() + .toolDefinition(toolDefinitionB) + .toolMethod(methodB) + .toolObject(new TestGenericClass()) + .build(); + + Prompt prompt = new Prompt(new UserMessage("Hello"), + ToolCallingChatOptions.builder() + .toolCallbacks(methodToolCallback, methodToolCallbackNeedToolContext) + .toolNames("toolA", "toolB") + .toolContext("key", "value") + .build()); + + ChatResponse chatResponse = ChatResponse.builder() + .generations(List.of(new Generation(new AssistantMessage("", Map.of(), + List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"), + new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}")))))) + .build(); + + ToolResponseMessage expectedToolResponse = new ToolResponseMessage( + List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON), + new ToolResponseMessage.ToolResponse("toolB", "toolB", + TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON))); + + ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse); + + assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; @@ -370,4 +417,31 @@ class DefaultToolCallingManagerTests { } + /** + * Test class with methods that use generic types. + */ + static class TestGenericClass { + + public final static String CALL_RESULT_JSON = """ + { + "result": "Mission accomplished!" + } + """; + + public final static String CALL_WITH_TOOL_CONTEXT_RESULT_JSON = """ + { + "result": "ToolContext mission accomplished!" + } + """; + + public String call(String toolInput) { + return CALL_RESULT_JSON; + } + + public String callWithToolContext(ToolContext toolContext) { + return CALL_WITH_TOOL_CONTEXT_RESULT_JSON; + } + + } + }