test: Add unit test to verify multiple method toolcallbacks with toolcontext
Auto-cherry-pick to 1.0.x Signed-off-by: Sun Yuhan <sunyuhan1998@users.noreply.github.com>
This commit is contained in:
committed by
Ilayaperumal Gopinathan
parent
16a7084faa
commit
b9a683460d
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.model.tool;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -27,6 +28,7 @@ import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.model.ToolContext;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.definition.DefaultToolDefinition;
|
||||
@@ -34,6 +36,7 @@ import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.tool.execution.ToolExecutionException;
|
||||
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
|
||||
import org.springframework.ai.tool.metadata.ToolMetadata;
|
||||
import org.springframework.ai.tool.method.MethodToolCallback;
|
||||
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
|
||||
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
|
||||
|
||||
@@ -45,6 +48,7 @@ import static org.mockito.Mockito.mock;
|
||||
* Unit tests for {@link DefaultToolCallingManager}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @author Sun Yuhan
|
||||
*/
|
||||
class DefaultToolCallingManagerTests {
|
||||
|
||||
@@ -317,6 +321,49 @@ class DefaultToolCallingManagerTests {
|
||||
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenMixedMethodToolCallsInChatResponseThenExecute() throws NoSuchMethodException {
|
||||
ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder().build();
|
||||
|
||||
ToolDefinition toolDefinitionA = ToolDefinition.builder().name("toolA").inputSchema("{}").build();
|
||||
Method methodA = TestGenericClass.class.getMethod("call", String.class);
|
||||
MethodToolCallback methodToolCallback = MethodToolCallback.builder()
|
||||
.toolDefinition(toolDefinitionA)
|
||||
.toolMethod(methodA)
|
||||
.toolObject(new TestGenericClass())
|
||||
.build();
|
||||
|
||||
ToolDefinition toolDefinitionB = ToolDefinition.builder().name("toolB").inputSchema("{}").build();
|
||||
Method methodB = TestGenericClass.class.getMethod("callWithToolContext", ToolContext.class);
|
||||
MethodToolCallback methodToolCallbackNeedToolContext = MethodToolCallback.builder()
|
||||
.toolDefinition(toolDefinitionB)
|
||||
.toolMethod(methodB)
|
||||
.toolObject(new TestGenericClass())
|
||||
.build();
|
||||
|
||||
Prompt prompt = new Prompt(new UserMessage("Hello"),
|
||||
ToolCallingChatOptions.builder()
|
||||
.toolCallbacks(methodToolCallback, methodToolCallbackNeedToolContext)
|
||||
.toolNames("toolA", "toolB")
|
||||
.toolContext("key", "value")
|
||||
.build());
|
||||
|
||||
ChatResponse chatResponse = ChatResponse.builder()
|
||||
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
|
||||
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"),
|
||||
new AssistantMessage.ToolCall("toolB", "function", "toolB", "{}"))))))
|
||||
.build();
|
||||
|
||||
ToolResponseMessage expectedToolResponse = new ToolResponseMessage(
|
||||
List.of(new ToolResponseMessage.ToolResponse("toolA", "toolA", TestGenericClass.CALL_RESULT_JSON),
|
||||
new ToolResponseMessage.ToolResponse("toolB", "toolB",
|
||||
TestGenericClass.CALL_WITH_TOOL_CONTEXT_RESULT_JSON)));
|
||||
|
||||
ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);
|
||||
|
||||
assertThat(toolExecutionResult.conversationHistory()).contains(expectedToolResponse);
|
||||
}
|
||||
|
||||
static class TestToolCallback implements ToolCallback {
|
||||
|
||||
private final ToolDefinition toolDefinition;
|
||||
@@ -370,4 +417,31 @@ class DefaultToolCallingManagerTests {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Test class with methods that use generic types.
|
||||
*/
|
||||
static class TestGenericClass {
|
||||
|
||||
public final static String CALL_RESULT_JSON = """
|
||||
{
|
||||
"result": "Mission accomplished!"
|
||||
}
|
||||
""";
|
||||
|
||||
public final static String CALL_WITH_TOOL_CONTEXT_RESULT_JSON = """
|
||||
{
|
||||
"result": "ToolContext mission accomplished!"
|
||||
}
|
||||
""";
|
||||
|
||||
public String call(String toolInput) {
|
||||
return CALL_RESULT_JSON;
|
||||
}
|
||||
|
||||
public String callWithToolContext(ToolContext toolContext) {
|
||||
return CALL_WITH_TOOL_CONTEXT_RESULT_JSON;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user