committed by
Ilayaperumal Gopinathan
parent
029e8a10af
commit
b6f29a493f
@@ -24,7 +24,6 @@ import io.modelcontextprotocol.client.McpSyncClient;
|
||||
import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider;
|
||||
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
|
||||
import org.springframework.ai.mcp.client.autoconfigure.properties.McpClientCommonProperties;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.AllNestedConditions;
|
||||
|
||||
@@ -129,7 +129,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
|
||||
}
|
||||
|
||||
public String getPathPrefix() {
|
||||
return pathPrefix;
|
||||
return this.pathPrefix;
|
||||
}
|
||||
|
||||
public void setPathPrefix(String pathPrefix) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
@@ -78,7 +77,7 @@ public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
@Override
|
||||
public List<String> findConversationIds() {
|
||||
return this.jdbcTemplate.queryForList(dialect.getSelectConversationIdsSql(), String.class);
|
||||
return this.jdbcTemplate.queryForList(this.dialect.getSelectConversationIdsSql(), String.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -41,13 +41,6 @@ import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.test.context.ContextConfiguration;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
@@ -78,7 +71,7 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.of(message));
|
||||
|
||||
assertThat(chatMemoryRepository.findConversationIds()).contains(conversationId);
|
||||
assertThat(this.chatMemoryRepository.findConversationIds()).contains(conversationId);
|
||||
|
||||
// Use dialect to get the appropriate SQL query
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
|
||||
@@ -103,7 +96,7 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
assertThat(chatMemoryRepository.findConversationIds()).contains(conversationId);
|
||||
assertThat(this.chatMemoryRepository.findConversationIds()).contains(conversationId);
|
||||
|
||||
// Use dialect to get the appropriate SQL query
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
|
||||
@@ -179,10 +172,10 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
|
||||
// Save messages in the expected order
|
||||
List<Message> orderedMessages = List.of(firstMessage, secondMessage, thirdMessage, fourthMessage);
|
||||
chatMemoryRepository.saveAll(conversationId, orderedMessages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, orderedMessages);
|
||||
|
||||
// Retrieve messages using the repository
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(4);
|
||||
|
||||
// Get the actual order from the retrieved messages
|
||||
|
||||
@@ -791,7 +791,7 @@ public class BedrockProxyChatModel implements ChatModel {
|
||||
|
||||
private Builder() {
|
||||
try {
|
||||
region = DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
this.region = DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
}
|
||||
catch (SdkClientException e) {
|
||||
logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", e);
|
||||
|
||||
@@ -37,9 +37,9 @@ class BedrockProxyChatModelTest {
|
||||
@Test
|
||||
void shouldIgnoreExceptionAndUseDefault() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion())
|
||||
when(this.awsRegionProviderBuilder.build().getRegion())
|
||||
.thenThrow(SdkClientException.builder().message("failed load").build());
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(this.awsRegionProviderBuilder);
|
||||
BedrockProxyChatModel.builder().build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Sinks;
|
||||
import reactor.core.publisher.Sinks.EmitFailureHandler;
|
||||
@@ -50,6 +49,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.ResponseStream;
|
||||
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.ObjectUtils;
|
||||
|
||||
/**
|
||||
* Abstract class for the Bedrock API. It provides the basic functionality to invoke the chat completion model and
|
||||
@@ -322,6 +322,20 @@ public abstract class AbstractBedrockApi<I, O, SO> {
|
||||
return eventSink.asFlux();
|
||||
}
|
||||
|
||||
private Region getRegion(Region region) {
|
||||
if (ObjectUtils.isEmpty(region)) {
|
||||
try {
|
||||
return DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
}
|
||||
catch (SdkClientException e) {
|
||||
throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
else {
|
||||
return region;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Encapsulates the metrics about the model invocation.
|
||||
* https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-claude.html
|
||||
@@ -341,16 +355,5 @@ public abstract class AbstractBedrockApi<I, O, SO> {
|
||||
@JsonProperty("invocationLatency") Long invocationLatency) {
|
||||
}
|
||||
|
||||
private Region getRegion(Region region) {
|
||||
if (ObjectUtils.isEmpty(region)) {
|
||||
try {
|
||||
return DefaultAwsRegionProviderChain.builder().build().getRegion();
|
||||
} catch (SdkClientException e) {
|
||||
throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e);
|
||||
}
|
||||
} else {
|
||||
return region;
|
||||
}
|
||||
}
|
||||
}
|
||||
// @formatter:on
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package org.springframework.ai.bedrock.api;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
@@ -28,11 +30,11 @@ import software.amazon.awssdk.core.exception.SdkClientException;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.mockStatic;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class AbstractBedrockApiTest {
|
||||
@@ -49,10 +51,10 @@ class AbstractBedrockApiTest {
|
||||
@Test
|
||||
void shouldLoadRegionFromAwsDefaults() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1);
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
when(this.awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1);
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(this.awsRegionProviderBuilder);
|
||||
AbstractBedrockApi<Object, Object, Object> testBedrockApi = new TestBedrockApi("modelId",
|
||||
awsCredentialsProvider, null, objectMapper, Duration.ofMinutes(5));
|
||||
this.awsCredentialsProvider, null, this.objectMapper, Duration.ofMinutes(5));
|
||||
assertThat(testBedrockApi.getRegion()).isEqualTo(Region.AF_SOUTH_1);
|
||||
}
|
||||
}
|
||||
@@ -60,10 +62,10 @@ class AbstractBedrockApiTest {
|
||||
@Test
|
||||
void shouldThrowIllegalArgumentIfAwsDefaultsFailed() {
|
||||
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
|
||||
when(awsRegionProviderBuilder.build().getRegion())
|
||||
when(this.awsRegionProviderBuilder.build().getRegion())
|
||||
.thenThrow(SdkClientException.builder().message("failed load").build());
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
|
||||
assertThatThrownBy(() -> new TestBedrockApi("modelId", awsCredentialsProvider, null, objectMapper,
|
||||
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(this.awsRegionProviderBuilder);
|
||||
assertThatThrownBy(() -> new TestBedrockApi("modelId", this.awsCredentialsProvider, null, this.objectMapper,
|
||||
Duration.ofMinutes(5)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("failed load");
|
||||
|
||||
@@ -16,6 +16,9 @@
|
||||
|
||||
package org.springframework.ai.chat.evaluation;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.evaluation.EvaluationRequest;
|
||||
@@ -24,9 +27,6 @@ import org.springframework.ai.evaluation.Evaluator;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Evaluates the relevancy of a response to a query based on the context provided.
|
||||
*/
|
||||
@@ -91,7 +91,7 @@ public class RelevancyEvaluator implements Evaluator {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static final class Builder {
|
||||
|
||||
private ChatClient.Builder chatClientBuilder;
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.springframework.ai.chat.evaluation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.rag.Query;
|
||||
import org.springframework.ai.rag.util.PromptAssert;
|
||||
|
||||
@@ -21,6 +21,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.rag.Query;
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.fail;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
Reference in New Issue
Block a user