Additional tests for the #1878
Add additional integration tests to ensure that the #1878 issue is resolved Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
This commit is contained in:
@@ -20,17 +20,21 @@ import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
|
||||
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
|
||||
import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClient.StreamResponseSpec;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
@@ -175,7 +179,7 @@ public class BedrockNovaChatClientIT {
|
||||
|
||||
// https://github.com/spring-projects/spring-ai/issues/1878
|
||||
@Test
|
||||
void toolAnnotationWeatherForecastTest() {
|
||||
void toolAnnotationWeatherForecast() {
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
|
||||
|
||||
@@ -189,6 +193,27 @@ public class BedrockNovaChatClientIT {
|
||||
assertThat(response).contains("20 degrees");
|
||||
}
|
||||
|
||||
@Test
|
||||
void toolAnnotationWeatherForecastStreaming() {
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
|
||||
|
||||
Flux<ChatResponse> responses = chatClient.prompt()
|
||||
.tools(new DummyWeatherForcastTools())
|
||||
.user("Get current weather in Amsterdam")
|
||||
.stream()
|
||||
.chatResponse();
|
||||
|
||||
String content = responses.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.filter(cr -> cr.getResult() != null)
|
||||
.map(cr -> cr.getResult().getOutput().getText())
|
||||
.collect(Collectors.joining());
|
||||
|
||||
assertThat(content).contains("20 degrees");
|
||||
}
|
||||
|
||||
public static class DummyWeatherForcastTools {
|
||||
|
||||
@Tool(description = "Get the current weather forcast in Amsterdam")
|
||||
@@ -217,6 +242,30 @@ public class BedrockNovaChatClientIT {
|
||||
assertThat(response.temp()).isEqualTo(30.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void supplierBasedToolCallingStreaming() {
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
|
||||
|
||||
Flux<ChatResponse> responses = chatClient.prompt()
|
||||
.toolCallbacks(FunctionToolCallback.builder("weather", new WeatherService())
|
||||
.description("Get the current weather")
|
||||
.inputType(Void.class)
|
||||
.build())
|
||||
.user("Get current weather in Amsterdam")
|
||||
.stream()
|
||||
.chatResponse();
|
||||
|
||||
String content = responses.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.filter(cr -> cr.getResult() != null)
|
||||
.map(cr -> cr.getResult().getOutput().getText())
|
||||
.collect(Collectors.joining());
|
||||
|
||||
assertThat(content).contains("30.0");
|
||||
}
|
||||
|
||||
public static class WeatherService implements Supplier<WeatherService.Response> {
|
||||
|
||||
public record Response(double temp) {
|
||||
|
||||
Reference in New Issue
Block a user