test: Add unit test to verify multiple method toolcallbacks with toolcontext

Auto-cherry-pick to 1.0.x

Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
This commit is contained in:
Sun Yuhan
2025-06-12 18:13:25 +08:00
committed by Ilayaperumal Gopinathan
parent 16a7084faa
commit b9a683460d

View File

@@ -16,6 +16,7 @@
package org.springframework.ai.model.tool;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
@@ -27,6 +28,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.DefaultToolDefinition;
@@ -34,6 +36,7 @@ import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.metadata.ToolMetadata;
import org.springframework.ai.tool.method.MethodToolCallback;
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
@@ -45,6 +48,7 @@ import static org.mockito.Mockito.mock;
* Unit tests for {@link DefaultToolCallingManager}.
*
* @author Thomas Vitale
* @author Sun Yuhan
*/
class DefaultToolCallingManagerTests {
@@ -317,6 +321,49 @@ class DefaultToolCallingManagerTests {
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
}
@Test
void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodException {
ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build();
ToolDefinition toolDefinitionA = ToolDefinition.builder().name("toolA").inputSchema("{}").build();
Method methodA = TestGenericClass.class.getMethod("call", String.class);
MethodToolCallback methodToolCallback = MethodToolCallback.builder()
.toolDefinition(toolDefinitionA)
.toolMethod(methodA)
.toolObject(new TestGenericClass())
.build();
ToolDefinition toolDefinitionB = ToolDefinition.builder().name("toolB").inputSchema("{}").build();
Method methodB = TestGenericClass.class.getMethod("callWithToolContext", ToolContext.class);
MethodToolCallback methodToolCallbackNeedToolContext = MethodToolCallback.builder()
.toolDefinition(toolDefinitionB)
.toolMethod(methodB)
.toolObject(new TestGenericClass())
.build();
Prompt prompt = new Prompt(new UserMessage("Hello"),
ToolCallingChatOptions.builder()
.toolCallbacks(methodToolCallback, methodToolCallbackNeedToolContext)
.toolNames("toolA", "toolB")
.toolContext("key", "value")
.build());
ChatResponse chatResponse = ChatResponse.builder()
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
.build();
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON),
new ToolResponseMessage.ToolResponse("toolB", "toolB",
TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON)));
ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
}
static class TestToolCallback implements ToolCallback {
private final ToolDefinition toolDefinition;
@@ -370,4 +417,31 @@ class DefaultToolCallingManagerTests {
}
/**
* Test class with methods that use generic types.
*/
static class TestGenericClass {
public final static String CALL_RESULT_JSON = """
{
"result": "Mission accomplished!"
}
""";
public final static String CALL_WITH_TOOL_CONTEXT_RESULT_JSON = """
{
"result": "ToolContext mission accomplished!"
}
""";
public String call(String toolInput) {
return CALL_RESULT_JSON;
}
public String callWithToolContext(ToolContext toolContext) {
return CALL_WITH_TOOL_CONTEXT_RESULT_JSON;
}
}
}