Additional tests for the #1878

Add additional integration tests to ensure that the #1878 issue is resolved

Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
This commit is contained in:
Christian Tzolov
2025-05-10 16:51:37 +02:00
parent f30ab261b4
commit ad783d9867

View File

@@ -20,17 +20,21 @@ import java.io.IOException;
import java.time.Duration;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClient.StreamResponseSpec;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.annotation.Tool;
@@ -175,7 +179,7 @@ public class BedrockNovaChatClientIT {
// https://github.com/spring-projects/spring-ai/issues/1878
@Test
void toolAnnotationWeatherForecastTest() {
void toolAnnotationWeatherForecast() {
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
@@ -189,6 +193,27 @@ public class BedrockNovaChatClientIT {
assertThat(response).contains("20 degrees");
}
@Test
void toolAnnotationWeatherForecastStreaming() {
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
Flux<ChatResponse> responses = chatClient.prompt()
.tools(new DummyWeatherForcastTools())
.user("Get current weather in Amsterdam")
.stream()
.chatResponse();
String content = responses.collectList()
.block()
.stream()
.filter(cr -> cr.getResult() != null)
.map(cr -> cr.getResult().getOutput().getText())
.collect(Collectors.joining());
assertThat(content).contains("20 degrees");
}
public static class DummyWeatherForcastTools {
@Tool(description = "Get the current weather forcast in Amsterdam")
@@ -217,6 +242,30 @@ public class BedrockNovaChatClientIT {
assertThat(response.temp()).isEqualTo(30.0);
}
@Test
void supplierBasedToolCallingStreaming() {
ChatClient chatClient = ChatClient.builder(this.chatModel).build();
Flux<ChatResponse> responses = chatClient.prompt()
.toolCallbacks(FunctionToolCallback.builder("weather", new WeatherService())
.description("Get the current weather")
.inputType(Void.class)
.build())
.user("Get current weather in Amsterdam")
.stream()
.chatResponse();
String content = responses.collectList()
.block()
.stream()
.filter(cr -> cr.getResult() != null)
.map(cr -> cr.getResult().getOutput().getText())
.collect(Collectors.joining());
assertThat(content).contains("30.0");
}
public static class WeatherService implements Supplier<WeatherService.Response> {
public record Response(double temp) {