Fixing miscellaneous checkstyle errors
Enabling checkstyle by default in the project build Signed-off-by: Soby Chacko <soby.chacko@broadcom.com>
This commit is contained in:
committed by
Ilayaperumal Gopinathan
parent
31feb4319b
commit
368be3a04f
@@ -21,14 +21,14 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
|
||||
@@ -21,7 +21,6 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.util.Assert;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
@@ -38,6 +37,7 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Memory is retrieved from a VectorStore added into the prompt's system text.
|
||||
@@ -50,7 +50,7 @@ import org.springframework.ai.vectorstore.VectorStore;
|
||||
* @author Mark Pollack
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
|
||||
public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
|
||||
|
||||
public static final String TOP_K = "chat_memory_vector_store_top_k";
|
||||
|
||||
@@ -104,7 +104,7 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return order;
|
||||
return this.order;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,7 +1,24 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.vectorstore;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
@@ -109,33 +109,6 @@ public class McpClientCommonProperties {
|
||||
*/
|
||||
private Toolcallback toolcallback = new Toolcallback();
|
||||
|
||||
/**
|
||||
* Represents a callback configuration for tools.
|
||||
* <p>
|
||||
* This record is used to encapsulate the configuration for enabling or disabling tool
|
||||
* callbacks in the MCP client.
|
||||
*
|
||||
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
|
||||
* true, the tool callback is active; otherwise, it is disabled.
|
||||
*/
|
||||
public static class Toolcallback {
|
||||
|
||||
/**
|
||||
* A boolean flag indicating whether the tool callback is enabled. If true, the
|
||||
* tool callback is active; otherwise, it is disabled.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
return this.enabled;
|
||||
}
|
||||
@@ -193,11 +166,38 @@ public class McpClientCommonProperties {
|
||||
}
|
||||
|
||||
public Toolcallback getToolcallback() {
|
||||
return toolcallback;
|
||||
return this.toolcallback;
|
||||
}
|
||||
|
||||
public void setToolcallback(Toolcallback toolcallback) {
|
||||
this.toolcallback = toolcallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a callback configuration for tools.
|
||||
* <p>
|
||||
* This record is used to encapsulate the configuration for enabling or disabling tool
|
||||
* callbacks in the MCP client.
|
||||
*
|
||||
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
|
||||
* true, the tool callback is active; otherwise, it is disabled.
|
||||
*/
|
||||
public static class Toolcallback {
|
||||
|
||||
/**
|
||||
* A boolean flag indicating whether the tool callback is enabled. If true, the
|
||||
* tool callback is active; otherwise, it is disabled.
|
||||
*/
|
||||
private boolean enabled = true;
|
||||
|
||||
public void setEnabled(boolean enabled) {
|
||||
this.enabled = enabled;
|
||||
}
|
||||
|
||||
public boolean isEnabled() {
|
||||
return this.enabled;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,13 +16,14 @@
|
||||
|
||||
package org.springframework.ai.mcp.client.autoconfigure.properties;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,13 +16,14 @@
|
||||
|
||||
package org.springframework.ai.mcp.client.autoconfigure.properties;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
|
||||
@@ -351,4 +351,4 @@ public class McpServerAutoConfiguration {
|
||||
return serverBuilder.build();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.mcp.server.autoconfigure;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.jackson.JacksonAutoConfiguration;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
@@ -25,8 +27,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.web.reactive.function.server.RouterFunction;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class McpWebFluxServerAutoConfigurationTests {
|
||||
|
||||
@@ -36,7 +37,7 @@ class McpWebFluxServerAutoConfigurationTests {
|
||||
|
||||
@Test
|
||||
void shouldConfigureWebFluxTransportWithCustomObjectMapper() {
|
||||
this.contextRunner.run((context) -> {
|
||||
this.contextRunner.run(context -> {
|
||||
assertThat(context).hasSingleBean(WebFluxSseServerTransportProvider.class);
|
||||
assertThat(context).hasSingleBean(RouterFunction.class);
|
||||
assertThat(context).hasSingleBean(McpServerProperties.class);
|
||||
@@ -48,6 +49,7 @@ class McpWebFluxServerAutoConfigurationTests {
|
||||
.isEnabled(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse();
|
||||
|
||||
// Test with a JSON payload containing unknown fields
|
||||
// CHECKSTYLE:OFF
|
||||
String jsonWithUnknownField = """
|
||||
{
|
||||
"tools": ["tool1", "tool2"],
|
||||
@@ -55,6 +57,7 @@ class McpWebFluxServerAutoConfigurationTests {
|
||||
"unknownField": "value"
|
||||
}
|
||||
""";
|
||||
// CHECKSTYLE:ON
|
||||
|
||||
// This should not throw an exception
|
||||
TestMessage message = objectMapper.readValue(jsonWithUnknownField, TestMessage.class);
|
||||
@@ -75,7 +78,7 @@ class McpWebFluxServerAutoConfigurationTests {
|
||||
private String name;
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
return this.name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
|
||||
@@ -20,6 +20,7 @@ import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClientCustomizer;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
@@ -30,7 +31,11 @@ import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.*;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.springframework.ai.model.chat.client.autoconfigure;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationHandler;
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
|
||||
@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfig
|
||||
|
||||
import com.datastax.oss.driver.api.core.CqlSession;
|
||||
|
||||
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
|
||||
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
|
||||
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;
|
||||
|
||||
@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.jdbc.autoconfigure;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
|
||||
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
|
||||
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
|
||||
@@ -59,7 +59,7 @@ public class JdbcChatMemoryRepositoryProperties {
|
||||
}
|
||||
|
||||
public String getPlatform() {
|
||||
return platform;
|
||||
return this.platform;
|
||||
}
|
||||
|
||||
public void setPlatform(String platform) {
|
||||
|
||||
@@ -74,11 +74,11 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
|
||||
|
||||
// Debug: Print current schemas and tables
|
||||
try {
|
||||
List<String> schemas = jdbcTemplate.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA",
|
||||
String.class);
|
||||
List<String> schemas = this.jdbcTemplate
|
||||
.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA", String.class);
|
||||
System.out.println("Available schemas: " + schemas);
|
||||
|
||||
List<String> tables = jdbcTemplate
|
||||
List<String> tables = this.jdbcTemplate
|
||||
.queryForList("SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES", String.class);
|
||||
System.out.println("Available tables: " + tables);
|
||||
}
|
||||
@@ -89,22 +89,22 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
|
||||
// Try a more direct approach with explicit SQL statements
|
||||
try {
|
||||
// Drop the table first if it exists to avoid any conflicts
|
||||
jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
|
||||
this.jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
|
||||
System.out.println("Dropped existing table if it existed");
|
||||
|
||||
// Create the table with a simplified schema
|
||||
jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY (" + "conversation_id VARCHAR(36) NOT NULL, "
|
||||
+ "content LONGVARCHAR NOT NULL, " + "type VARCHAR(10) NOT NULL, "
|
||||
+ "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
|
||||
this.jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY ("
|
||||
+ "conversation_id VARCHAR(36) NOT NULL, " + "content LONGVARCHAR NOT NULL, "
|
||||
+ "type VARCHAR(10) NOT NULL, " + "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
|
||||
System.out.println("Created table with simplified schema");
|
||||
|
||||
// Create index
|
||||
jdbcTemplate.execute(
|
||||
this.jdbcTemplate.execute(
|
||||
"CREATE INDEX SPRING_AI_CHAT_MEMORY_IDX ON SPRING_AI_CHAT_MEMORY(conversation_id, timestamp DESC)");
|
||||
System.out.println("Created index");
|
||||
|
||||
// Verify table was created
|
||||
boolean tableExists = jdbcTemplate.queryForObject(
|
||||
boolean tableExists = this.jdbcTemplate.queryForObject(
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
|
||||
Integer.class) > 0;
|
||||
System.out.println("Table SPRING_AI_CHAT_MEMORY exists after creation: " + tableExists);
|
||||
@@ -125,7 +125,7 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
|
||||
@Test
|
||||
public void useAutoConfiguredChatMemoryWithJdbc() {
|
||||
// Check that the custom schema initializer is present
|
||||
assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();
|
||||
assertThat(this.context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();
|
||||
|
||||
// Debug: List all schema-hsqldb.sql resources on the classpath
|
||||
try {
|
||||
@@ -144,7 +144,7 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
|
||||
|
||||
// Verify the table exists by executing a direct query
|
||||
try {
|
||||
boolean tableExists = jdbcTemplate.queryForObject(
|
||||
boolean tableExists = this.jdbcTemplate.queryForObject(
|
||||
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
|
||||
Integer.class) > 0;
|
||||
System.out.println("Table SPRING_AI_CHAT_MEMORY exists: " + tableExists);
|
||||
@@ -157,10 +157,10 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
|
||||
}
|
||||
|
||||
// Now test the ChatMemory functionality
|
||||
assertThat(context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
|
||||
assertThat(context.getBean(JdbcChatMemoryRepository.class)).isNotNull();
|
||||
assertThat(this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
|
||||
assertThat(this.context.getBean(JdbcChatMemoryRepository.class)).isNotNull();
|
||||
|
||||
var chatMemory = context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
|
||||
var chatMemory = this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
|
||||
var conversationId = java.util.UUID.randomUUID().toString();
|
||||
var userMessage = new UserMessage("Message from the user");
|
||||
|
||||
|
||||
@@ -55,12 +55,14 @@ class JdbcChatMemoryRepositoryPostgresqlAutoConfigurationIT {
|
||||
|
||||
@Test
|
||||
void jdbcChatMemoryScriptDatabaseInitializer_shouldNotRunSchemaInit() {
|
||||
// CHECKSTYLE:OFF
|
||||
this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=never")
|
||||
.run(context -> {
|
||||
assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer");
|
||||
// Optionally, check that the schema is not initialized (could check table
|
||||
// absence if needed)
|
||||
});
|
||||
// CHECKSTYLE:ON
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -1,6 +1,19 @@
|
||||
/*
|
||||
* Integration test for SQL Server using Testcontainers, following the same structure as the PostgreSQL test.
|
||||
* Copyright 2024-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.model.chat.memory.repository.jdbc.autoconfigure;
|
||||
|
||||
import java.time.Duration;
|
||||
@@ -8,6 +21,10 @@ import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testcontainers.containers.MSSQLServerContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
|
||||
@@ -19,13 +36,12 @@ import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.testcontainers.containers.MSSQLServerContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/*
|
||||
* Integration test for SQL Server using Testcontainers, following the same structure as the PostgreSQL test.
|
||||
*/
|
||||
@Testcontainers
|
||||
class JdbcChatMemoryRepositorySqlServerAutoConfigurationIT {
|
||||
|
||||
@@ -58,9 +74,7 @@ class JdbcChatMemoryRepositorySqlServerAutoConfigurationIT {
|
||||
@Test
|
||||
void jdbcChatMemoryScriptDatabaseInitializer_shouldNotRunSchemaInit() {
|
||||
this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=never")
|
||||
.run(context -> {
|
||||
assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer");
|
||||
});
|
||||
.run(context -> assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer"));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.neo4j.autoconfigure;
|
||||
|
||||
import org.neo4j.driver.Driver;
|
||||
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
|
||||
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
|
||||
@@ -23,14 +23,13 @@ import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.springframework.ai.model.chat.memory.autoconfigure;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
|
||||
@@ -40,7 +41,7 @@ class ChatMemoryAutoConfigurationTests {
|
||||
|
||||
@Test
|
||||
void defaultConfiguration() {
|
||||
contextRunner.run(context -> {
|
||||
this.contextRunner.run(context -> {
|
||||
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
|
||||
assertThat(context).hasSingleBean(ChatMemory.class);
|
||||
});
|
||||
@@ -48,7 +49,7 @@ class ChatMemoryAutoConfigurationTests {
|
||||
|
||||
@Test
|
||||
void whenChatMemoryRepositoryExists() {
|
||||
contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
|
||||
this.contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
|
||||
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
|
||||
assertThat(context).hasBean("customChatMemoryRepository");
|
||||
assertThat(context).doesNotHaveBean("chatMemoryRepository");
|
||||
@@ -57,7 +58,7 @@ class ChatMemoryAutoConfigurationTests {
|
||||
|
||||
@Test
|
||||
void whenChatMemoryExists() {
|
||||
contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
|
||||
this.contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
|
||||
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
|
||||
assertThat(context).hasBean("customChatMemoryRepository");
|
||||
assertThat(context).doesNotHaveBean("chatMemoryRepository");
|
||||
@@ -71,7 +72,7 @@ class ChatMemoryAutoConfigurationTests {
|
||||
|
||||
@Bean
|
||||
ChatMemoryRepository customChatMemoryRepository() {
|
||||
return customChatMemoryRepository;
|
||||
return this.customChatMemoryRepository;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -83,7 +84,7 @@ class ChatMemoryAutoConfigurationTests {
|
||||
|
||||
@Bean
|
||||
ChatMemory customChatMemory() {
|
||||
return customChatMemory;
|
||||
return this.customChatMemory;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,10 +16,13 @@
|
||||
|
||||
package org.springframework.ai.model.chat.observation.autoconfigure;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.micrometer.core.instrument.MeterRegistry;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
@@ -33,13 +36,15 @@ import org.springframework.ai.model.observation.ErrorLoggingObservationHandler;
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.*;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Auto-configuration for Spring AI chat model observations.
|
||||
*
|
||||
|
||||
@@ -16,10 +16,13 @@
|
||||
|
||||
package org.springframework.ai.model.chat.observation.autoconfigure;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.micrometer.core.instrument.composite.CompositeMeterRegistry;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
import org.springframework.ai.chat.observation.ChatModelCompletionObservationHandler;
|
||||
import org.springframework.ai.chat.observation.ChatModelMeterObservationHandler;
|
||||
@@ -35,8 +38,6 @@ import org.springframework.boot.test.system.OutputCaptureExtension;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
|
||||
@@ -19,12 +19,17 @@ package org.springframework.ai.model.image.observation.autoconfigure;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationContext;
|
||||
import org.springframework.ai.image.observation.ImageModelPromptContentObservationHandler;
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.*;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.springframework.ai.model.image.observation.autoconfigure;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import org.springframework.ai.image.observation.ImageModelObservationContext;
|
||||
import org.springframework.ai.image.observation.ImageModelPromptContentObservationHandler;
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package org.springframework.ai.model.bedrock.titan.autoconfigure;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
package org.springframework.ai.model.deepseek.autoconfigure;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
|
||||
import org.springframework.ai.model.SimpleApiKey;
|
||||
import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
import org.springframework.ai.model.SimpleApiKey;
|
||||
import org.springframework.ai.model.SpringAIModelProperties;
|
||||
import org.springframework.ai.model.SpringAIModels;
|
||||
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
|
||||
|
||||
@@ -71,7 +71,7 @@ public class DeepSeekChatProperties extends DeepSeekParentProperties {
|
||||
}
|
||||
|
||||
public String getCompletionsPath() {
|
||||
return completionsPath;
|
||||
return this.completionsPath;
|
||||
}
|
||||
|
||||
public void setCompletionsPath(String completionsPath) {
|
||||
@@ -79,7 +79,7 @@ public class DeepSeekChatProperties extends DeepSeekParentProperties {
|
||||
}
|
||||
|
||||
public String getBetaPrefixPath() {
|
||||
return betaPrefixPath;
|
||||
return this.betaPrefixPath;
|
||||
}
|
||||
|
||||
public void setBetaPrefixPath(String betaPrefixPath) {
|
||||
|
||||
@@ -16,10 +16,15 @@
|
||||
|
||||
package org.springframework.ai.model.deepseek.autoconfigure;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -28,10 +33,6 @@ import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.springframework.ai.model.deepseek.autoconfigure;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.deepseek.DeepSeekChatModel;
|
||||
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
|
||||
@@ -16,10 +16,16 @@
|
||||
|
||||
package org.springframework.ai.model.deepseek.autoconfigure.tool;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
@@ -36,11 +42,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
||||
@@ -16,10 +16,15 @@
|
||||
|
||||
package org.springframework.ai.model.deepseek.autoconfigure.tool;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
@@ -31,10 +36,6 @@ import org.springframework.ai.model.deepseek.autoconfigure.DeepSeekChatAutoConfi
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
||||
@@ -16,10 +16,16 @@
|
||||
|
||||
package org.springframework.ai.model.deepseek.autoconfigure.tool;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
@@ -34,11 +40,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.Description;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
||||
@@ -16,14 +16,14 @@
|
||||
|
||||
package org.springframework.ai.model.deepseek.autoconfigure.tool;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* Mock 3rd party weather service.
|
||||
*
|
||||
|
||||
@@ -54,7 +54,7 @@ public class OpenAiImageProperties extends OpenAiParentProperties {
|
||||
}
|
||||
|
||||
public String getImagesPath() {
|
||||
return imagesPath;
|
||||
return this.imagesPath;
|
||||
}
|
||||
|
||||
public void setImagesPath(String imagesPath) {
|
||||
|
||||
@@ -16,9 +16,13 @@
|
||||
|
||||
package org.springframework.ai.model.tool.autoconfigure;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
@@ -40,9 +44,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.support.GenericApplicationContext;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Auto-configuration for common tool calling features of {@link ChatModel}.
|
||||
*
|
||||
|
||||
@@ -39,7 +39,7 @@ public class ToolCallingProperties {
|
||||
private boolean includeContent = false;
|
||||
|
||||
public boolean isIncludeContent() {
|
||||
return includeContent;
|
||||
return this.includeContent;
|
||||
}
|
||||
|
||||
public void setIncludeContent(boolean includeContent) {
|
||||
|
||||
@@ -116,9 +116,7 @@ class ToolCallingAutoConfigurationTests {
|
||||
void observationFilterDefault() {
|
||||
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
|
||||
.withUserConfiguration(Config.class)
|
||||
.run(context -> {
|
||||
assertThat(context).doesNotHaveBean(ToolCallingContentObservationFilter.class);
|
||||
});
|
||||
.run(context -> assertThat(context).doesNotHaveBean(ToolCallingContentObservationFilter.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -126,9 +124,7 @@ class ToolCallingAutoConfigurationTests {
|
||||
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
|
||||
.withPropertyValues("spring.ai.tools.observations.include-content=true")
|
||||
.withUserConfiguration(Config.class)
|
||||
.run(context -> {
|
||||
assertThat(context).hasSingleBean(ToolCallingContentObservationFilter.class);
|
||||
});
|
||||
.run(context -> assertThat(context).hasSingleBean(ToolCallingContentObservationFilter.class));
|
||||
}
|
||||
|
||||
static class WeatherService {
|
||||
|
||||
@@ -64,7 +64,7 @@ public class CosmosDBVectorStoreAutoConfiguration {
|
||||
}
|
||||
|
||||
CosmosClientBuilder builder = new CosmosClientBuilder().endpoint(properties.getEndpoint())
|
||||
.userAgentSuffix(agentSuffix);
|
||||
.userAgentSuffix(this.agentSuffix);
|
||||
|
||||
if (properties.getKey() == null || properties.getKey().isEmpty()) {
|
||||
builder.credential(new DefaultAzureCredentialBuilder().build());
|
||||
|
||||
@@ -16,12 +16,12 @@
|
||||
|
||||
package org.springframework.ai.vectorstore.cosmosdb.autoconfigure;
|
||||
|
||||
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
|
||||
/**
|
||||
* Configuration properties for CosmosDB Vector Store.
|
||||
*
|
||||
|
||||
@@ -39,7 +39,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties {
|
||||
private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME;
|
||||
|
||||
public String getTenantName() {
|
||||
return tenantName;
|
||||
return this.tenantName;
|
||||
}
|
||||
|
||||
public void setTenantName(String tenantName) {
|
||||
@@ -47,7 +47,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties {
|
||||
}
|
||||
|
||||
public String getDatabaseName() {
|
||||
return databaseName;
|
||||
return this.databaseName;
|
||||
}
|
||||
|
||||
public void setDatabaseName(String databaseName) {
|
||||
|
||||
@@ -19,12 +19,17 @@ package org.springframework.ai.vectorstore.observation.autoconfigure;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
|
||||
import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.*;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.springframework.ai.vectorstore.observation.autoconfigure;
|
||||
import io.micrometer.tracing.Tracer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
|
||||
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
|
||||
import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler;
|
||||
|
||||
@@ -52,7 +52,6 @@ import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.boot.ssl.SslBundles;
|
||||
|
||||
@@ -98,7 +98,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
|
||||
}
|
||||
|
||||
public String getSslBundle() {
|
||||
return sslBundle;
|
||||
return this.sslBundle;
|
||||
}
|
||||
|
||||
public void setSslBundle(String sslBundle) {
|
||||
@@ -106,7 +106,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
|
||||
}
|
||||
|
||||
public Duration getConnectionTimeout() {
|
||||
return connectionTimeout;
|
||||
return this.connectionTimeout;
|
||||
}
|
||||
|
||||
public void setConnectionTimeout(Duration connectionTimeout) {
|
||||
@@ -114,7 +114,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
|
||||
}
|
||||
|
||||
public Duration getReadTimeout() {
|
||||
return readTimeout;
|
||||
return this.readTimeout;
|
||||
}
|
||||
|
||||
public void setReadTimeout(Duration readTimeout) {
|
||||
|
||||
@@ -47,7 +47,7 @@ import org.springframework.util.Assert;
|
||||
* @author Mick Semb Wever
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class CassandraChatMemoryRepository implements ChatMemoryRepository {
|
||||
public final class CassandraChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
public static final String CONVERSATION_TS = CassandraChatMemoryRepository.class.getSimpleName()
|
||||
+ "_message_timestamp";
|
||||
@@ -125,7 +125,7 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
Instant instant = Instant.now();
|
||||
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
|
||||
BoundStatementBuilder builder = addStmt.boundStatementBuilder();
|
||||
BoundStatementBuilder builder = this.addStmt.boundStatementBuilder();
|
||||
|
||||
for (int k = 0; k < primaryKeys.size(); ++k) {
|
||||
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);
|
||||
|
||||
@@ -34,8 +34,6 @@ import com.datastax.oss.driver.api.core.type.UserDefinedType;
|
||||
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
|
||||
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
|
||||
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
|
||||
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
|
||||
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
|
||||
import com.datastax.oss.driver.api.querybuilder.schema.CreateTable;
|
||||
import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart;
|
||||
import com.datastax.oss.driver.api.querybuilder.schema.CreateTableWithOptions;
|
||||
@@ -140,13 +138,13 @@ public final class CassandraChatMemoryRepositoryConfig {
|
||||
Preconditions.checkState(this.session.getMetadata()
|
||||
.getKeyspace(this.schema.keyspace())
|
||||
.get()
|
||||
.getUserDefinedType(messageUDT)
|
||||
.getUserDefinedType(this.messageUDT)
|
||||
.isPresent(), "table %s does not exist");
|
||||
|
||||
UserDefinedType udt = this.session.getMetadata()
|
||||
.getKeyspace(this.schema.keyspace())
|
||||
.get()
|
||||
.getUserDefinedType(messageUDT)
|
||||
.getUserDefinedType(this.messageUDT)
|
||||
.get();
|
||||
|
||||
Preconditions.checkState(udt.contains(this.messageUdtTimestampColumn), "field %s does not exist",
|
||||
@@ -186,7 +184,7 @@ public final class CassandraChatMemoryRepositoryConfig {
|
||||
String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name();
|
||||
|
||||
CreateTableWithOptions createTableWithOptions = createTable
|
||||
.withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
|
||||
.withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(this.messageUDT, true)))
|
||||
.withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC)
|
||||
// TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() when
|
||||
// available
|
||||
@@ -201,11 +199,11 @@ public final class CassandraChatMemoryRepositoryConfig {
|
||||
|
||||
private void ensureMessageTypeExist() {
|
||||
|
||||
SimpleStatement stmt = SchemaBuilder.createType(messageUDT)
|
||||
SimpleStatement stmt = SchemaBuilder.createType(this.messageUDT)
|
||||
.ifNotExists()
|
||||
.withField(messageUdtTimestampColumn, DataTypes.TIMESTAMP)
|
||||
.withField(messageUdtTypeColumn, DataTypes.TEXT)
|
||||
.withField(messageUdtContentColumn, DataTypes.TEXT)
|
||||
.withField(this.messageUdtTimestampColumn, DataTypes.TIMESTAMP)
|
||||
.withField(this.messageUdtTypeColumn, DataTypes.TEXT)
|
||||
.withField(this.messageUdtContentColumn, DataTypes.TEXT)
|
||||
.build();
|
||||
|
||||
this.session.execute(stmt.setKeyspace(this.schema.keyspace));
|
||||
@@ -222,7 +220,7 @@ public final class CassandraChatMemoryRepositoryConfig {
|
||||
if (tableMetadata.getColumn(this.messagesColumn).isEmpty()) {
|
||||
|
||||
SimpleStatement stmt = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table())
|
||||
.addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
|
||||
.addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(this.messageUDT, true)))
|
||||
.build();
|
||||
|
||||
logger.debug("Executing {}", stmt.getQuery());
|
||||
|
||||
@@ -26,7 +26,6 @@ import com.datastax.oss.driver.api.core.cql.ResultSet;
|
||||
import com.datastax.oss.driver.api.core.data.UdtValue;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.testcontainers.cassandra.CassandraContainer;
|
||||
|
||||
@@ -16,7 +16,10 @@
|
||||
|
||||
package org.springframework.ai.chat.memory.repository.jdbc;
|
||||
|
||||
import java.sql.*;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -24,6 +27,9 @@ import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@@ -37,15 +43,8 @@ import org.springframework.jdbc.core.RowMapper;
|
||||
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.transaction.PlatformTransactionManager;
|
||||
import org.springframework.transaction.TransactionDefinition;
|
||||
import org.springframework.transaction.TransactionException;
|
||||
import org.springframework.transaction.TransactionStatus;
|
||||
import org.springframework.transaction.support.AbstractPlatformTransactionManager;
|
||||
import org.springframework.transaction.support.DefaultTransactionStatus;
|
||||
import org.springframework.transaction.support.TransactionTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* An implementation of {@link ChatMemoryRepository} for JDBC.
|
||||
@@ -56,7 +55,7 @@ import org.slf4j.LoggerFactory;
|
||||
* @author Mark Pollack
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
private final JdbcTemplate jdbcTemplate;
|
||||
|
||||
@@ -78,7 +77,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
@Override
|
||||
public List<String> findConversationIds() {
|
||||
List<String> conversationIds = this.jdbcTemplate.query(dialect.getSelectConversationIdsSql(), rs -> {
|
||||
List<String> conversationIds = this.jdbcTemplate.query(this.dialect.getSelectConversationIdsSql(), rs -> {
|
||||
var ids = new ArrayList<String>();
|
||||
while (rs.next()) {
|
||||
ids.add(rs.getString(1));
|
||||
@@ -91,7 +90,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
@Override
|
||||
public List<Message> findByConversationId(String conversationId) {
|
||||
Assert.hasText(conversationId, "conversationId cannot be null or empty");
|
||||
return this.jdbcTemplate.query(dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
|
||||
return this.jdbcTemplate.query(this.dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -100,9 +99,9 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
Assert.notNull(messages, "messages cannot be null");
|
||||
Assert.noNullElements(messages, "messages cannot contain null elements");
|
||||
|
||||
transactionTemplate.execute(status -> {
|
||||
this.transactionTemplate.execute(status -> {
|
||||
deleteByConversationId(conversationId);
|
||||
jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(),
|
||||
this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
|
||||
new AddBatchPreparedStatement(conversationId, messages));
|
||||
return null;
|
||||
});
|
||||
@@ -111,7 +110,11 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
@Override
|
||||
public void deleteByConversationId(String conversationId) {
|
||||
Assert.hasText(conversationId, "conversationId cannot be null or empty");
|
||||
this.jdbcTemplate.update(dialect.getDeleteMessagesSql(), conversationId);
|
||||
this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
|
||||
@@ -128,7 +131,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
ps.setString(1, this.conversationId);
|
||||
ps.setString(2, message.getText());
|
||||
ps.setString(3, message.getMessageType().name());
|
||||
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
|
||||
ps.setTimestamp(4, new Timestamp(this.instantSeq.getAndIncrement()));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -158,11 +161,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
public static final class Builder {
|
||||
|
||||
private JdbcTemplate jdbcTemplate;
|
||||
|
||||
|
||||
@@ -55,16 +55,21 @@ public interface JdbcChatMemoryRepositoryDialect {
|
||||
// Simple detection (could be improved)
|
||||
try {
|
||||
String url = dataSource.getConnection().getMetaData().getURL().toLowerCase();
|
||||
if (url.contains("postgresql"))
|
||||
if (url.contains("postgresql")) {
|
||||
return new PostgresChatMemoryRepositoryDialect();
|
||||
if (url.contains("mysql"))
|
||||
}
|
||||
if (url.contains("mysql")) {
|
||||
return new MysqlChatMemoryRepositoryDialect();
|
||||
if (url.contains("mariadb"))
|
||||
}
|
||||
if (url.contains("mariadb")) {
|
||||
return new MysqlChatMemoryRepositoryDialect();
|
||||
if (url.contains("sqlserver"))
|
||||
}
|
||||
if (url.contains("sqlserver")) {
|
||||
return new SqlServerChatMemoryRepositoryDialect();
|
||||
if (url.contains("hsqldb"))
|
||||
}
|
||||
if (url.contains("hsqldb")) {
|
||||
return new HsqldbChatMemoryRepositoryDialect();
|
||||
}
|
||||
// Add more as needed
|
||||
}
|
||||
catch (Exception ignored) {
|
||||
|
||||
@@ -16,31 +16,30 @@
|
||||
|
||||
package org.springframework.ai.chat.memory.repository.jdbc;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
@@ -58,7 +57,7 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
|
||||
@Test
|
||||
void correctChatMemoryRepositoryInstance() {
|
||||
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
|
||||
assertThat(this.chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@@ -72,13 +71,14 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
case TOOL -> throw new IllegalArgumentException("TOOL message type not supported in this test");
|
||||
};
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(message));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.of(message));
|
||||
|
||||
// Use dialect to get the appropriate SQL query
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
|
||||
.from(this.jdbcTemplate.getDataSource());
|
||||
String selectSql = dialect.getSelectMessagesSql()
|
||||
.replace("content, type", "conversation_id, content, type, timestamp");
|
||||
var result = jdbcTemplate.queryForMap(selectSql, conversationId);
|
||||
var result = this.jdbcTemplate.queryForMap(selectSql, conversationId);
|
||||
|
||||
assertThat(result.size()).isEqualTo(4);
|
||||
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
|
||||
@@ -94,13 +94,14 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
new UserMessage("Message from user - " + conversationId),
|
||||
new SystemMessage("Message from system - " + conversationId));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
// Use dialect to get the appropriate SQL query
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
|
||||
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
|
||||
.from(this.jdbcTemplate.getDataSource());
|
||||
String selectSql = dialect.getSelectMessagesSql()
|
||||
.replace("content, type", "conversation_id, content, type, timestamp");
|
||||
var results = jdbcTemplate.queryForList(selectSql, conversationId);
|
||||
var results = this.jdbcTemplate.queryForList(selectSql, conversationId);
|
||||
|
||||
assertThat(results).hasSize(messages.size());
|
||||
|
||||
@@ -114,12 +115,12 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
|
||||
}
|
||||
|
||||
var count = chatMemoryRepository.findByConversationId(conversationId).size();
|
||||
var count = this.chatMemoryRepository.findByConversationId(conversationId).size();
|
||||
assertThat(count).isEqualTo(messages.size());
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
|
||||
|
||||
count = chatMemoryRepository.findByConversationId(conversationId).size();
|
||||
count = this.chatMemoryRepository.findByConversationId(conversationId).size();
|
||||
assertThat(count).isEqualTo(1);
|
||||
}
|
||||
|
||||
@@ -131,9 +132,9 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
new UserMessage("Message from user - " + conversationId),
|
||||
new SystemMessage("Message from system - " + conversationId));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
var results = chatMemoryRepository.findByConversationId(conversationId);
|
||||
var results = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(results.size()).isEqualTo(messages.size());
|
||||
assertThat(results).isEqualTo(messages);
|
||||
@@ -146,12 +147,12 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
new UserMessage("Message from user - " + conversationId),
|
||||
new SystemMessage("Message from system - " + conversationId));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
chatMemoryRepository.deleteByConversationId(conversationId);
|
||||
this.chatMemoryRepository.deleteByConversationId(conversationId);
|
||||
|
||||
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?",
|
||||
Integer.class, conversationId);
|
||||
var count = this.jdbcTemplate.queryForObject(
|
||||
"SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?", Integer.class, conversationId);
|
||||
|
||||
assertThat(count).isZero();
|
||||
}
|
||||
@@ -160,8 +161,8 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
|
||||
void testMessageOrder() {
|
||||
// Create a repository using the from method to detect the dialect
|
||||
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
|
||||
.jdbcTemplate(jdbcTemplate)
|
||||
.dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()))
|
||||
.jdbcTemplate(this.jdbcTemplate)
|
||||
.dialect(JdbcChatMemoryRepositoryDialect.from(this.jdbcTemplate.getDataSource()))
|
||||
.build();
|
||||
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
|
||||
@@ -19,6 +19,7 @@ package org.springframework.ai.chat.memory.repository.jdbc;
|
||||
import java.sql.Connection;
|
||||
import java.sql.DatabaseMetaData;
|
||||
import java.sql.SQLException;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.chat.memory.repository.jdbc;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -53,7 +54,7 @@ class JdbcChatMemoryRepositoryPostgresqlIT extends AbstractJdbcChatMemoryReposit
|
||||
void repositoryWithExplicitTransactionManager() {
|
||||
// Get the repository with explicit transaction manager
|
||||
ChatMemoryRepository repositoryWithTxManager = TestConfiguration
|
||||
.chatMemoryRepositoryWithTransactionManager(jdbcTemplate, jdbcTemplate.getDataSource());
|
||||
.chatMemoryRepositoryWithTransactionManager(this.jdbcTemplate, this.jdbcTemplate.getDataSource());
|
||||
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
var messages = List.<Message>of(new AssistantMessage("Message with transaction manager - " + conversationId),
|
||||
|
||||
@@ -1,18 +1,45 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.memory.repository.neo4j;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.neo4j.driver.Session;
|
||||
import org.neo4j.driver.Transaction;
|
||||
import org.neo4j.driver.TransactionContext;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.content.MediaContent;
|
||||
import org.springframework.util.MimeType;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* An implementation of {@link ChatMemoryRepository} for Neo4J
|
||||
*
|
||||
@@ -31,9 +58,9 @@ public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
@Override
|
||||
public List<String> findConversationIds() {
|
||||
return config.getDriver()
|
||||
return this.config.getDriver()
|
||||
.executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id")
|
||||
.withParameters(Map.of("sessionLabel", config.getSessionLabel()))
|
||||
.withParameters(Map.of("sessionLabel", this.config.getSessionLabel()))
|
||||
.execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList()));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.memory.repository.neo4j;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.GraphDatabase;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.neo4j.driver.Result;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -28,8 +44,9 @@ class Neo4JChatMemoryRepositoryConfigIT {
|
||||
|
||||
@AfterAll
|
||||
static void closeDriver() {
|
||||
if (driver != null)
|
||||
if (driver != null) {
|
||||
driver.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -44,10 +61,12 @@ class Neo4JChatMemoryRepositoryConfigIT {
|
||||
while (result.hasNext()) {
|
||||
var record = result.next();
|
||||
String name = record.get("name").asString();
|
||||
if ("session_conversation_id_index".equals(name))
|
||||
if ("session_conversation_id_index".equals(name)) {
|
||||
sessionIndexFound = true;
|
||||
if ("message_index_index".equals(name))
|
||||
}
|
||||
if ("message_index_index".equals(name)) {
|
||||
messageIndexFound = true;
|
||||
}
|
||||
}
|
||||
// Then
|
||||
assertThat(sessionIndexFound).isTrue();
|
||||
|
||||
@@ -16,6 +16,14 @@
|
||||
|
||||
package org.springframework.ai.chat.memory.repository.neo4j;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -24,6 +32,11 @@ import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.Result;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@@ -34,18 +47,6 @@ import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.testcontainers.containers.Neo4jContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.utility.DockerImageName;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@@ -74,24 +75,24 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
|
||||
config = Neo4jChatMemoryRepositoryConfig.builder().withDriver(driver).build();
|
||||
chatMemoryRepository = new Neo4jChatMemoryRepository(config);
|
||||
this.driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
|
||||
this.config = Neo4jChatMemoryRepositoryConfig.builder().withDriver(this.driver).build();
|
||||
this.chatMemoryRepository = new Neo4jChatMemoryRepository(this.config);
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
// Clean up all data after each test
|
||||
try (Session session = driver.session()) {
|
||||
try (Session session = this.driver.session()) {
|
||||
session.run("MATCH (n) DETACH DELETE n");
|
||||
}
|
||||
driver.close();
|
||||
this.driver.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
void correctChatMemoryRepositoryInstance() {
|
||||
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
|
||||
assertThat(chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
|
||||
assertThat(this.chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
|
||||
assertThat(this.chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@@ -101,8 +102,8 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
var conversationId = UUID.randomUUID().toString();
|
||||
Message message = createMessageByType(content + " - " + conversationId, messageType);
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
@@ -114,10 +115,10 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
}
|
||||
|
||||
// Verify directly in the database
|
||||
try (Session session = driver.session()) {
|
||||
try (Session session = this.driver.session()) {
|
||||
var result = session.run(
|
||||
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
|
||||
.formatted(config.getSessionLabel(), config.getMessageLabel()),
|
||||
.formatted(this.config.getSessionLabel(), this.config.getMessageLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
assertThat(result.single().get("count").asLong()).isEqualTo(1);
|
||||
}
|
||||
@@ -131,8 +132,8 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
new SystemMessage("Message from system - " + conversationId),
|
||||
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(messages.size());
|
||||
|
||||
@@ -155,8 +156,8 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
messages.add(new UserMessage("Message " + i));
|
||||
}
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(messages.size());
|
||||
|
||||
@@ -173,11 +174,14 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
var conversationId2 = UUID.randomUUID().toString();
|
||||
var conversationId3 = UUID.randomUUID().toString();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId1, List.<Message>of(new UserMessage("Message for conversation 1")));
|
||||
chatMemoryRepository.saveAll(conversationId2, List.<Message>of(new UserMessage("Message for conversation 2")));
|
||||
chatMemoryRepository.saveAll(conversationId3, List.<Message>of(new UserMessage("Message for conversation 3")));
|
||||
this.chatMemoryRepository.saveAll(conversationId1,
|
||||
List.<Message>of(new UserMessage("Message for conversation 1")));
|
||||
this.chatMemoryRepository.saveAll(conversationId2,
|
||||
List.<Message>of(new UserMessage("Message for conversation 2")));
|
||||
this.chatMemoryRepository.saveAll(conversationId3,
|
||||
List.<Message>of(new UserMessage("Message for conversation 3")));
|
||||
|
||||
List<String> conversationIds = chatMemoryRepository.findConversationIds();
|
||||
List<String> conversationIds = this.chatMemoryRepository.findConversationIds();
|
||||
|
||||
assertThat(conversationIds).hasSize(3);
|
||||
assertThat(conversationIds).contains(conversationId1, conversationId2, conversationId3);
|
||||
@@ -189,22 +193,21 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
List<Message> messages = List.of(new AssistantMessage("Message from assistant"),
|
||||
new UserMessage("Message from user"), new SystemMessage("Message from system"));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, messages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messages);
|
||||
|
||||
// Verify messages were saved
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
|
||||
// Delete the conversation
|
||||
chatMemoryRepository.deleteByConversationId(conversationId);
|
||||
this.chatMemoryRepository.deleteByConversationId(conversationId);
|
||||
|
||||
// Verify messages were deleted
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
|
||||
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
|
||||
|
||||
// Verify directly in the database
|
||||
try (Session session = driver.session()) {
|
||||
var result = session.run(
|
||||
"MATCH (s:%s {id:$conversationId}) RETURN count(s) as count".formatted(config.getSessionLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
try (Session session = this.driver.session()) {
|
||||
var result = session.run("MATCH (s:%s {id:$conversationId}) RETURN count(s) as count"
|
||||
.formatted(this.config.getSessionLabel()), Map.of("conversationId", conversationId));
|
||||
assertThat(result.single().get("count").asLong()).isZero();
|
||||
}
|
||||
}
|
||||
@@ -216,17 +219,17 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
// Save initial messages
|
||||
List<Message> initialMessages = List.of(new UserMessage("Initial message 1"),
|
||||
new UserMessage("Initial message 2"), new UserMessage("Initial message 3"));
|
||||
chatMemoryRepository.saveAll(conversationId, initialMessages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
|
||||
|
||||
// Verify initial messages were saved
|
||||
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
|
||||
|
||||
// Replace with new messages
|
||||
List<Message> newMessages = List.of(new UserMessage("New message 1"), new UserMessage("New message 2"));
|
||||
chatMemoryRepository.saveAll(conversationId, newMessages);
|
||||
this.chatMemoryRepository.saveAll(conversationId, newMessages);
|
||||
|
||||
// Verify only new messages exist
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(2);
|
||||
assertThat(retrievedMessages.get(0).getText()).isEqualTo("New message 1");
|
||||
assertThat(retrievedMessages.get(1).getText()).isEqualTo("New message 2");
|
||||
@@ -246,9 +249,9 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
|
||||
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
UserMessage retrievedMessage = (UserMessage) retrievedMessages.get(0);
|
||||
@@ -264,9 +267,9 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
|
||||
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
AssistantMessage retrievedMessage = (AssistantMessage) retrievedMessages.get(0);
|
||||
@@ -283,9 +286,9 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
|
||||
Map.of("metadataKey", "metadataValue"));
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
|
||||
ToolResponseMessage retrievedMessage = (ToolResponseMessage) retrievedMessages.get(0);
|
||||
@@ -305,8 +308,8 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
.metadata(customMetadata)
|
||||
.build();
|
||||
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
|
||||
assertThat(retrievedMessages).hasSize(1);
|
||||
Message retrievedMessage = retrievedMessages.get(0);
|
||||
@@ -333,29 +336,29 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
// 1. Setup: Create a conversation with some initial messages
|
||||
UserMessage initialMessage1 = new UserMessage("Initial message 1");
|
||||
AssistantMessage initialMessage2 = new AssistantMessage("Initial response 1");
|
||||
chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
|
||||
this.chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
|
||||
|
||||
// Verify initial messages are there
|
||||
List<Message> messagesAfterInitialSave = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> messagesAfterInitialSave = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(messagesAfterInitialSave).hasSize(2);
|
||||
|
||||
// 2. Action: Call saveAll with an empty list
|
||||
chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
|
||||
this.chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
|
||||
|
||||
// 3. Assertions:
|
||||
// a) No messages should be found for the conversationId
|
||||
List<Message> messagesAfterEmptySave = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> messagesAfterEmptySave = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(messagesAfterEmptySave).isEmpty();
|
||||
|
||||
// b) The conversationId itself should no longer be listed (because
|
||||
// deleteByConversationId removes the session node)
|
||||
List<String> conversationIds = chatMemoryRepository.findConversationIds();
|
||||
List<String> conversationIds = this.chatMemoryRepository.findConversationIds();
|
||||
assertThat(conversationIds).doesNotContain(conversationId);
|
||||
|
||||
// c) Verify directly in Neo4j that the conversation node is gone
|
||||
try (Session session = driver.session()) {
|
||||
try (Session session = this.driver.session()) {
|
||||
Result result = session.run(
|
||||
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(config.getSessionLabel()),
|
||||
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(this.config.getSessionLabel()),
|
||||
Map.of("conversationId", conversationId));
|
||||
assertThat(result.hasNext()).isFalse(); // No conversation node should exist
|
||||
}
|
||||
@@ -372,9 +375,9 @@ class Neo4jChatMemoryRepositoryIT {
|
||||
.build();
|
||||
|
||||
List<Message> messagesToSave = List.of(messageWithEmptyContent, messageWithEmptyMetadata);
|
||||
chatMemoryRepository.saveAll(conversationId, messagesToSave);
|
||||
this.chatMemoryRepository.saveAll(conversationId, messagesToSave);
|
||||
|
||||
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
|
||||
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
|
||||
assertThat(retrievedMessages).hasSize(2);
|
||||
|
||||
// Verify first message (empty content)
|
||||
|
||||
@@ -51,7 +51,6 @@ import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -70,6 +69,7 @@ import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
@@ -62,7 +62,7 @@ import org.springframework.web.reactive.function.client.WebClient;
|
||||
* @author Claudio Silva Junior
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class AnthropicApi {
|
||||
public final class AnthropicApi {
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
|
||||
@@ -142,7 +142,7 @@ public class AnthropicApiIT {
|
||||
.maxTokens(1500)
|
||||
.stream(true)
|
||||
.temperature(0.8)
|
||||
.tools(tools)
|
||||
.tools(this.tools)
|
||||
.build();
|
||||
|
||||
List<ChatCompletionResponse> responses = this.anthropicApi.chatCompletionStream(chatCompletionRequest)
|
||||
|
||||
@@ -16,11 +16,13 @@
|
||||
|
||||
package org.springframework.ai.azure.openai;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ import software.amazon.awssdk.regions.Region;
|
||||
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
|
||||
import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClient.StreamResponseSpec;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.content.Media;
|
||||
@@ -171,12 +170,6 @@ public class BedrockNovaChatClientIT {
|
||||
assertThat(response).contains("30", "10", "15");
|
||||
}
|
||||
|
||||
public record WeatherRequest(String location, String unit) {
|
||||
}
|
||||
|
||||
public record WeatherResponse(int temp, String unit) {
|
||||
}
|
||||
|
||||
// https://github.com/spring-projects/spring-ai/issues/1878
|
||||
@Test
|
||||
void toolAnnotationWeatherForecast() {
|
||||
@@ -214,15 +207,6 @@ public class BedrockNovaChatClientIT {
|
||||
assertThat(content).contains("20 degrees");
|
||||
}
|
||||
|
||||
public static class DummyWeatherForecastTools {
|
||||
|
||||
@Tool(description = "Get the current weather forecast in Amsterdam")
|
||||
String getCurrentDateTime() {
|
||||
return "Weather is hot and sunny with a temperature of 20 degrees";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// https://github.com/spring-projects/spring-ai/issues/1878
|
||||
@Test
|
||||
void supplierBasedToolCalling() {
|
||||
@@ -266,17 +250,6 @@ public class BedrockNovaChatClientIT {
|
||||
assertThat(content).contains("30.0");
|
||||
}
|
||||
|
||||
public static class WeatherService implements Supplier<WeatherService.Response> {
|
||||
|
||||
public record Response(double temp) {
|
||||
}
|
||||
|
||||
public Response get() {
|
||||
return new Response(30.0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
public static class Config {
|
||||
|
||||
@@ -295,4 +268,31 @@ public class BedrockNovaChatClientIT {
|
||||
|
||||
}
|
||||
|
||||
public record WeatherRequest(String location, String unit) {
|
||||
}
|
||||
|
||||
public record WeatherResponse(int temp, String unit) {
|
||||
}
|
||||
|
||||
public static class DummyWeatherForecastTools {
|
||||
|
||||
@Tool(description = "Get the current weather forecast in Amsterdam")
|
||||
String getCurrentDateTime() {
|
||||
return "Weather is hot and sunny with a temperature of 20 degrees";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class WeatherService implements Supplier<WeatherService.Response> {
|
||||
|
||||
public Response get() {
|
||||
return new Response(30.0);
|
||||
}
|
||||
|
||||
public record Response(double temp) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ import java.io.UncheckedIOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
|
||||
@@ -20,6 +20,8 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -34,9 +36,6 @@ import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.Observation;
|
||||
|
||||
/**
|
||||
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
|
||||
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in
|
||||
|
||||
@@ -22,6 +22,7 @@ import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.micrometer.observation.tck.TestObservationRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
@@ -40,8 +41,6 @@ import org.springframework.core.io.DefaultResourceLoader;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import io.micrometer.observation.tck.TestObservationRegistry;
|
||||
|
||||
@SpringBootTest
|
||||
@RequiresAwsCredentials
|
||||
class BedrockTitanEmbeddingModelIT {
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
package org.springframework.ai.deepseek;
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.content.Media;
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.content.Media;
|
||||
|
||||
public class DeepSeekAssistantMessage extends AssistantMessage {
|
||||
|
||||
private Boolean prefix;
|
||||
@@ -50,7 +66,7 @@ public class DeepSeekAssistantMessage extends AssistantMessage {
|
||||
}
|
||||
|
||||
public Boolean getPrefix() {
|
||||
return prefix;
|
||||
return this.prefix;
|
||||
}
|
||||
|
||||
public void setPrefix(Boolean prefix) {
|
||||
@@ -58,7 +74,7 @@ public class DeepSeekAssistantMessage extends AssistantMessage {
|
||||
}
|
||||
|
||||
public String getReasoningContent() {
|
||||
return reasoningContent;
|
||||
return this.reasoningContent;
|
||||
}
|
||||
|
||||
public void setReasoningContent(String reasoningContent) {
|
||||
|
||||
@@ -16,16 +16,32 @@
|
||||
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.metadata.*;
|
||||
import org.springframework.ai.chat.model.*;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.model.MessageAggregator;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationContext;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
|
||||
@@ -41,7 +57,11 @@ import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Too
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.deepseek.api.common.DeepSeekConstants;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.tool.*;
|
||||
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
@@ -49,12 +69,6 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek}
|
||||
@@ -248,12 +262,12 @@ public class DeepSeekChatModel implements ChatModel {
|
||||
}
|
||||
|
||||
// @formatter:off
|
||||
Map<String, Object> metadata = Map.of(
|
||||
"id", chatCompletion2.id(),
|
||||
"role", roleMap.getOrDefault(id, ""),
|
||||
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
|
||||
);
|
||||
// @formatter:on
|
||||
Map<String, Object> metadata = Map.of(
|
||||
"id", chatCompletion2.id(),
|
||||
"role", roleMap.getOrDefault(id, ""),
|
||||
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
|
||||
);
|
||||
// @formatter:on
|
||||
return buildGeneration(choice, metadata);
|
||||
}).toList();
|
||||
DeepSeekApi.Usage usage = chatCompletion2.usage();
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2024-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,12 +13,23 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
import org.springframework.ai.deepseek.api.ResponseFormat;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
@@ -26,8 +37,6 @@ import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Chat completions options for the DeepSeek chat API.
|
||||
* <a href="https://platform.deepseek.com/api-docs/api/create-chat-completion">DeepSeek
|
||||
@@ -323,7 +332,7 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
|
||||
return Objects.hash(this.model, this.frequencyPenalty, this.logprobs, this.topLogprobs,
|
||||
this.maxTokens, this.presencePenalty, this.responseFormat,
|
||||
this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
|
||||
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext);
|
||||
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext);
|
||||
}
|
||||
|
||||
|
||||
@@ -351,6 +360,28 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
|
||||
&& Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled);
|
||||
}
|
||||
|
||||
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
|
||||
return DeepSeekChatOptions.builder()
|
||||
.model(fromOptions.getModel())
|
||||
.frequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.logprobs(fromOptions.getLogprobs())
|
||||
.topLogprobs(fromOptions.getTopLogprobs())
|
||||
.maxTokens(fromOptions.getMaxTokens())
|
||||
.presencePenalty(fromOptions.getPresencePenalty())
|
||||
.responseFormat(fromOptions.getResponseFormat())
|
||||
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
|
||||
.temperature(fromOptions.getTemperature())
|
||||
.topP(fromOptions.getTopP())
|
||||
.tools(fromOptions.getTools())
|
||||
.toolChoice(fromOptions.getToolChoice())
|
||||
.toolCallbacks(
|
||||
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
|
||||
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
|
||||
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
|
||||
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
|
||||
.build();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
protected DeepSeekChatOptions options;
|
||||
@@ -472,26 +503,4 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
|
||||
|
||||
}
|
||||
|
||||
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
|
||||
return DeepSeekChatOptions.builder()
|
||||
.model(fromOptions.getModel())
|
||||
.frequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.logprobs(fromOptions.getLogprobs())
|
||||
.topLogprobs(fromOptions.getTopLogprobs())
|
||||
.maxTokens(fromOptions.getMaxTokens())
|
||||
.presencePenalty(fromOptions.getPresencePenalty())
|
||||
.responseFormat(fromOptions.getResponseFormat())
|
||||
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
|
||||
.temperature(fromOptions.getTemperature())
|
||||
.topP(fromOptions.getTopP())
|
||||
.tools(fromOptions.getTools())
|
||||
.toolChoice(fromOptions.getToolChoice())
|
||||
.toolCallbacks(
|
||||
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
|
||||
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
|
||||
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
|
||||
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,6 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.aot;
|
||||
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
@@ -35,8 +36,9 @@ public class DeepSeekRuntimeHints implements RuntimeHintsRegistrar {
|
||||
@Override
|
||||
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
|
||||
var mcs = MemberCategory.values();
|
||||
for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class))
|
||||
for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class)) {
|
||||
hints.reflection().registerType(tr, mcs);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -47,10 +47,6 @@ import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
|
||||
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BASE_URL;
|
||||
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BETA_PATH;
|
||||
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_COMPLETIONS_PATH;
|
||||
|
||||
/**
|
||||
* Single class implementation of the DeepSeek Chat Completion API:
|
||||
* https://platform.deepseek.com/api-docs/api/create-chat-completion
|
||||
@@ -196,6 +192,19 @@ public class DeepSeekApi {
|
||||
.flatMap(mono -> mono);
|
||||
}
|
||||
|
||||
private String getEndpoint(ChatCompletionRequest request) {
|
||||
boolean isPrefix = request.messages.stream()
|
||||
.map(ChatCompletionMessage::prefix)
|
||||
.filter(Objects::nonNull)
|
||||
.anyMatch(prefix -> prefix);
|
||||
String endpointPrefix = isPrefix ? this.betaPrefixPath : "";
|
||||
return endpointPrefix + this.completionsPath;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* DeepSeek Chat Completion
|
||||
* <a href="https://api-docs.deepseek.com/quick_start/pricing">Models</a>
|
||||
@@ -226,12 +235,12 @@ public class DeepSeekApi {
|
||||
}
|
||||
|
||||
public String getValue() {
|
||||
return value;
|
||||
return this.value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return value;
|
||||
return this.value;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -506,10 +515,9 @@ public class DeepSeekApi {
|
||||
@JsonProperty("temperature") Double temperature,
|
||||
@JsonProperty("top_p") Double topP,
|
||||
@JsonProperty("logprobs") Boolean logprobs,
|
||||
@JsonProperty("top_logprobs") Integer topLogprobs,
|
||||
@JsonProperty("top_logprobs") Integer topLogprobs,
|
||||
@JsonProperty("tools") List<FunctionTool> tools,
|
||||
@JsonProperty("tool_choice") Object toolChoice)
|
||||
{
|
||||
@JsonProperty("tool_choice") Object toolChoice) {
|
||||
|
||||
|
||||
/**
|
||||
@@ -520,8 +528,8 @@ public class DeepSeekApi {
|
||||
* as they become available, with the stream terminated by a data: [DONE] message.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
|
||||
this(messages, null, null, null, null, null,
|
||||
null, stream, null, null, null, null, null, null);
|
||||
this(messages, null, null, null, null, null,
|
||||
null, stream, null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -532,9 +540,9 @@ public class DeepSeekApi {
|
||||
* @param temperature What sampling temperature to use, between 0 and 1.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
|
||||
this(messages, model, null,
|
||||
null, null, null, null, false, temperature, null,
|
||||
null, null, null,null);
|
||||
this(messages, model, null,
|
||||
null, null, null, null, false, temperature, null,
|
||||
null, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -547,9 +555,9 @@ public class DeepSeekApi {
|
||||
* as they become available, with the stream terminated by a data: [DONE] message.
|
||||
*/
|
||||
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
|
||||
this(messages, model, null,
|
||||
null, null, null, null, stream, temperature, null,
|
||||
null, null, null,null);
|
||||
this(messages, model, null,
|
||||
null, null, null, null, stream, temperature, null,
|
||||
null, null, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -600,8 +608,7 @@ public class DeepSeekApi {
|
||||
@JsonProperty("tool_calls")
|
||||
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
|
||||
@JsonProperty("prefix") Boolean prefix,
|
||||
@JsonProperty("reasoning_content") String reasoningContent
|
||||
) { // @formatter:on
|
||||
@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on
|
||||
|
||||
/**
|
||||
* Create a chat completion message with the given content and role. All other
|
||||
@@ -898,30 +905,17 @@ public class DeepSeekApi {
|
||||
|
||||
}
|
||||
|
||||
private String getEndpoint(ChatCompletionRequest request) {
|
||||
boolean isPrefix = request.messages.stream()
|
||||
.map(ChatCompletionMessage::prefix)
|
||||
.filter(Objects::nonNull)
|
||||
.anyMatch(prefix -> prefix);
|
||||
String endpointPrefix = isPrefix ? betaPrefixPath : "";
|
||||
return endpointPrefix + completionsPath;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private String baseUrl = DEFAULT_BASE_URL;
|
||||
private String baseUrl = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BASE_URL;
|
||||
|
||||
private ApiKey apiKey;
|
||||
|
||||
private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
|
||||
|
||||
private String completionsPath = DEFAULT_COMPLETIONS_PATH;
|
||||
private String completionsPath = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_COMPLETIONS_PATH;
|
||||
|
||||
private String betaPrefixPath = DEFAULT_BETA_PATH;
|
||||
private String betaPrefixPath = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BETA_PATH;
|
||||
|
||||
private RestClient.Builder restClientBuilder = RestClient.builder();
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@
|
||||
|
||||
package org.springframework.ai.deepseek.api;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk.ChunkChoice;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
|
||||
@@ -25,9 +28,6 @@ import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Rol
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Helper class to support Streaming function calling. It can merge the streamed
|
||||
* ChatCompletionChunk in case of function calling message.
|
||||
|
||||
@@ -16,12 +16,12 @@
|
||||
|
||||
package org.springframework.ai.deepseek.api;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* An object specifying the format that the model must output. Setting to { "type":
|
||||
* "json_object" } enables JSON Output, which guarantees the message the model generates
|
||||
@@ -42,7 +42,7 @@ import java.util.Objects;
|
||||
*/
|
||||
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public class ResponseFormat {
|
||||
public final class ResponseFormat {
|
||||
|
||||
/**
|
||||
* Type Must be one of 'text', 'json_object'.
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,6 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.api.common;
|
||||
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
@@ -20,7 +21,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
public class DeepSeekConstants {
|
||||
public final class DeepSeekConstants {
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "https://api.deepseek.com";
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,9 +13,11 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
|
||||
|
||||
@@ -16,15 +16,22 @@
|
||||
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.*;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.retry.TransientAiException;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -33,9 +40,6 @@ import org.springframework.retry.RetryContext;
|
||||
import org.springframework.retry.RetryListener;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.mockito.ArgumentMatchers.isA;
|
||||
@@ -65,7 +69,6 @@ public class DeepSeekRetryTests {
|
||||
.defaultOptions(DeepSeekChatOptions.builder().build())
|
||||
.retryTemplate(retryTemplate)
|
||||
.build();
|
||||
;
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,6 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek;
|
||||
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,15 +13,17 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.aot;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
import org.springframework.aot.hint.RuntimeHints;
|
||||
import org.springframework.aot.hint.TypeReference;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
|
||||
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,16 +13,22 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.api;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.*;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatModel;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@@ -37,7 +43,7 @@ public class DeepSeekApiIT {
|
||||
@Test
|
||||
void chatCompletionEntity() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
ResponseEntity<ChatCompletion> response = deepSeekApi.chatCompletionEntity(
|
||||
ResponseEntity<ChatCompletion> response = this.deepSeekApi.chatCompletionEntity(
|
||||
new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, false));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
@@ -47,7 +53,7 @@ public class DeepSeekApiIT {
|
||||
@Test
|
||||
void chatCompletionStream() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
Flux<ChatCompletionChunk> response = deepSeekApi.chatCompletionStream(
|
||||
Flux<ChatCompletionChunk> response = this.deepSeekApi.chatCompletionStream(
|
||||
new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, true));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
|
||||
@@ -16,14 +16,14 @@
|
||||
|
||||
package org.springframework.ai.deepseek.api;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonClassDescription;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
* @author Geng Rong
|
||||
*/
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,6 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.chat;
|
||||
|
||||
import java.util.List;
|
||||
@@ -30,7 +31,7 @@ public class ActorsFilms {
|
||||
}
|
||||
|
||||
public String getActor() {
|
||||
return actor;
|
||||
return this.actor;
|
||||
}
|
||||
|
||||
public void setActor(String actor) {
|
||||
@@ -38,7 +39,7 @@ public class ActorsFilms {
|
||||
}
|
||||
|
||||
public List<String> getMovies() {
|
||||
return movies;
|
||||
return this.movies;
|
||||
}
|
||||
|
||||
public void setMovies(List<String> movies) {
|
||||
@@ -47,7 +48,7 @@ public class ActorsFilms {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}';
|
||||
return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}';
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,10 +16,18 @@
|
||||
|
||||
package org.springframework.ai.deepseek.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
@@ -34,13 +42,6 @@ import org.springframework.ai.deepseek.api.MockWeatherService;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -13,12 +13,21 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.deepseek.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
@@ -32,21 +41,16 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
|
||||
import org.springframework.ai.deepseek.DeepSeekChatOptions;
|
||||
import org.springframework.ai.deepseek.DeepSeekTestConfiguration;
|
||||
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
|
||||
import org.springframework.ai.deepseek.api.DeepSeekApi;
|
||||
import org.springframework.ai.deepseek.api.MockWeatherService;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.convert.support.DefaultConversionService;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
@@ -71,10 +75,10 @@ class DeepSeekChatModelIT {
|
||||
void roleTest() {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
|
||||
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
|
||||
@@ -118,7 +122,7 @@ class DeepSeekChatModelIT {
|
||||
format))
|
||||
.build();
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -141,14 +145,11 @@ class DeepSeekChatModelIT {
|
||||
.variables(Map.of("format", format))
|
||||
.build();
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
|
||||
}
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverterRecords() {
|
||||
|
||||
@@ -165,7 +166,7 @@ class DeepSeekChatModelIT {
|
||||
.variables(Map.of("format", format))
|
||||
.build();
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
|
||||
logger.info("" + actorsFilms);
|
||||
@@ -190,7 +191,7 @@ class DeepSeekChatModelIT {
|
||||
.build();
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = streamingChatModel.stream(prompt)
|
||||
String generationTextFromStream = this.streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -225,7 +226,7 @@ class DeepSeekChatModelIT {
|
||||
UserMessage userMessage = new UserMessage(userMessageContent);
|
||||
Message assistantMessage = new DeepSeekAssistantMessage("{\"code\":200,\"result\":{\"total\":1,\"data\":[1");
|
||||
Prompt prompt = new Prompt(List.of(userMessage, assistantMessage));
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
assertThat(response.getResult().getOutput().getText().equals(",2,3]}}"));
|
||||
}
|
||||
|
||||
@@ -239,7 +240,7 @@ class DeepSeekChatModelIT {
|
||||
.model(DeepSeekApi.ChatModel.DEEPSEEK_REASONER.getValue())
|
||||
.build();
|
||||
Prompt prompt = new Prompt("9.11 and 9.8, which is greater?", promptOptions);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
|
||||
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
|
||||
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
|
||||
@@ -258,7 +259,7 @@ class DeepSeekChatModelIT {
|
||||
.build();
|
||||
|
||||
Prompt prompt = new Prompt(messages, promptOptions);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
|
||||
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
|
||||
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
|
||||
@@ -267,7 +268,7 @@ class DeepSeekChatModelIT {
|
||||
messages.add(new AssistantMessage(Objects.requireNonNull(deepSeekAssistantMessage.getText())));
|
||||
messages.add(new UserMessage("How many Rs are there in the word 'strawberry'?"));
|
||||
Prompt prompt2 = new Prompt(messages, promptOptions);
|
||||
ChatResponse response2 = chatModel.call(prompt2);
|
||||
ChatResponse response2 = this.chatModel.call(prompt2);
|
||||
|
||||
DeepSeekAssistantMessage deepSeekAssistantMessage2 = (DeepSeekAssistantMessage) response2.getResult()
|
||||
.getOutput();
|
||||
@@ -275,4 +276,7 @@ class DeepSeekChatModelIT {
|
||||
assertThat(deepSeekAssistantMessage2.getText()).isNotEmpty();
|
||||
}
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -16,11 +16,16 @@
|
||||
|
||||
package org.springframework.ai.deepseek.chat;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.micrometer.observation.tck.TestObservationRegistry;
|
||||
import io.micrometer.observation.tck.TestObservationRegistryAssert;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
@@ -36,10 +41,6 @@ import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;
|
||||
|
||||
@@ -40,7 +40,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApiConstants;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
@@ -39,7 +39,6 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -66,6 +65,7 @@ import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
|
||||
@@ -364,14 +364,14 @@ class MistralAiChatModelIT {
|
||||
|
||||
UserMessage userMessage1 = new UserMessage("My name is James Bond");
|
||||
memory.add(conversationId, userMessage1);
|
||||
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response1).isNotNull();
|
||||
memory.add(conversationId, response1.getResult().getOutput());
|
||||
|
||||
UserMessage userMessage2 = new UserMessage("What is my name?");
|
||||
memory.add(conversationId, userMessage2);
|
||||
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response2).isNotNull();
|
||||
memory.add(conversationId, response2.getResult().getOutput());
|
||||
@@ -396,7 +396,7 @@ class MistralAiChatModelIT {
|
||||
chatMemory.add(conversationId, prompt.getInstructions());
|
||||
|
||||
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
ChatResponse chatResponse = chatModel.call(promptWithMemory);
|
||||
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
|
||||
while (chatResponse.hasToolCalls()) {
|
||||
@@ -405,7 +405,7 @@ class MistralAiChatModelIT {
|
||||
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
|
||||
.get(toolExecutionResult.conversationHistory().size() - 1));
|
||||
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
chatResponse = chatModel.call(promptWithMemory);
|
||||
chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
}
|
||||
|
||||
@@ -415,7 +415,7 @@ class MistralAiChatModelIT {
|
||||
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
|
||||
chatMemory.add(conversationId, newUserMessage);
|
||||
|
||||
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
|
||||
assertThat(newResponse).isNotNull();
|
||||
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");
|
||||
|
||||
@@ -53,7 +53,6 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConven
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.oci.ServingModeHelper;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
@@ -267,16 +267,19 @@ public class OCICohereChatOptions implements ChatOptions {
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(model, maxTokens, compartment, servingMode, preambleOverride, temperature, topP, topK, stop,
|
||||
frequencyPenalty, presencePenalty, documents, tools);
|
||||
return Objects.hash(this.model, this.maxTokens, this.compartment, this.servingMode, this.preambleOverride,
|
||||
this.temperature, this.topP, this.topK, this.stop, this.frequencyPenalty, this.presencePenalty,
|
||||
this.documents, this.tools);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o)
|
||||
if (this == o) {
|
||||
return true;
|
||||
if (o == null || getClass() != o.getClass())
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
OCICohereChatOptions that = (OCICohereChatOptions) o;
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@
|
||||
|
||||
package org.springframework.ai.oci.cohere;
|
||||
|
||||
import com.oracle.bmc.generativeaiinference.model.CohereTool;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.oracle.bmc.generativeaiinference.model.CohereTool;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
|
||||
@@ -344,7 +344,8 @@ public class OllamaChatModel implements ChatModel {
|
||||
return Flux.just(ChatResponse.builder().from(response)
|
||||
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
|
||||
.build());
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
// Send the tool execution result back to the model.
|
||||
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
|
||||
response);
|
||||
|
||||
@@ -54,9 +54,11 @@ import org.springframework.web.reactive.function.client.WebClient;
|
||||
* @since 0.8.0
|
||||
*/
|
||||
// @formatter:off
|
||||
public class OllamaApi {
|
||||
public final class OllamaApi {
|
||||
|
||||
public static Builder builder() { return new Builder(); }
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null.";
|
||||
|
||||
|
||||
@@ -280,14 +280,14 @@ class OllamaChatModelIT extends BaseOllamaIT {
|
||||
|
||||
UserMessage userMessage1 = new UserMessage("My name is James Bond");
|
||||
memory.add(conversationId, userMessage1);
|
||||
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response1).isNotNull();
|
||||
memory.add(conversationId, response1.getResult().getOutput());
|
||||
|
||||
UserMessage userMessage2 = new UserMessage("What is my name?");
|
||||
memory.add(conversationId, userMessage2);
|
||||
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response2).isNotNull();
|
||||
memory.add(conversationId, response2.getResult().getOutput());
|
||||
@@ -312,7 +312,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
|
||||
chatMemory.add(conversationId, prompt.getInstructions());
|
||||
|
||||
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
ChatResponse chatResponse = chatModel.call(promptWithMemory);
|
||||
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
|
||||
while (chatResponse.hasToolCalls()) {
|
||||
@@ -321,7 +321,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
|
||||
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
|
||||
.get(toolExecutionResult.conversationHistory().size() - 1));
|
||||
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
chatResponse = chatModel.call(promptWithMemory);
|
||||
chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
|
||||
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
|
||||
chatMemory.add(conversationId, newUserMessage);
|
||||
|
||||
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
|
||||
assertThat(newResponse).isNotNull();
|
||||
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");
|
||||
|
||||
@@ -43,7 +43,6 @@ import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -74,6 +73,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
|
||||
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.support.UsageCalculator;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.core.io.ByteArrayResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
@@ -296,6 +296,31 @@ public class OpenAiApi {
|
||||
});
|
||||
}
|
||||
|
||||
// Package-private getters for mutate/copy
|
||||
String getBaseUrl() {
|
||||
return this.baseUrl;
|
||||
}
|
||||
|
||||
ApiKey getApiKey() {
|
||||
return this.apiKey;
|
||||
}
|
||||
|
||||
MultiValueMap<String, String> getHeaders() {
|
||||
return this.headers;
|
||||
}
|
||||
|
||||
String getCompletionsPath() {
|
||||
return this.completionsPath;
|
||||
}
|
||||
|
||||
String getEmbeddingsPath() {
|
||||
return this.embeddingsPath;
|
||||
}
|
||||
|
||||
ResponseErrorHandler getResponseErrorHandler() {
|
||||
return this.responseErrorHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI Chat Completion Models.
|
||||
* <p>
|
||||
@@ -1193,7 +1218,7 @@ public class OpenAiApi {
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize,
|
||||
@JsonProperty("user_location") UserLocation userLocation) {
|
||||
@JsonProperty("user_location") UserLocation userLocation) {
|
||||
|
||||
/**
|
||||
* High level guidance for the amount of context window space to use for the
|
||||
@@ -1229,11 +1254,11 @@ public class OpenAiApi {
|
||||
*/
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record UserLocation(@JsonProperty("type") String type,
|
||||
@JsonProperty("approximate") Approximate approximate) {
|
||||
@JsonProperty("approximate") Approximate approximate) {
|
||||
|
||||
@JsonInclude(Include.NON_NULL)
|
||||
public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country,
|
||||
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
|
||||
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1891,29 +1916,4 @@ public class OpenAiApi {
|
||||
|
||||
}
|
||||
|
||||
// Package-private getters for mutate/copy
|
||||
String getBaseUrl() {
|
||||
return this.baseUrl;
|
||||
}
|
||||
|
||||
ApiKey getApiKey() {
|
||||
return this.apiKey;
|
||||
}
|
||||
|
||||
MultiValueMap<String, String> getHeaders() {
|
||||
return this.headers;
|
||||
}
|
||||
|
||||
String getCompletionsPath() {
|
||||
return this.completionsPath;
|
||||
}
|
||||
|
||||
String getEmbeddingsPath() {
|
||||
return this.embeddingsPath;
|
||||
}
|
||||
|
||||
ResponseErrorHandler getResponseErrorHandler() {
|
||||
return this.responseErrorHandler;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -36,26 +36,29 @@ class OpenAiChatModelMutateTests {
|
||||
private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build();
|
||||
|
||||
private final OpenAiChatModel baseModel = OpenAiChatModel.builder()
|
||||
.openAiApi(baseApi)
|
||||
.openAiApi(this.baseApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build())
|
||||
.build();
|
||||
|
||||
@Test
|
||||
void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
|
||||
// Mutate for GPT-4
|
||||
OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build();
|
||||
OpenAiChatModel gpt4Model = baseModel.mutate()
|
||||
OpenAiApi gpt4Api = this.baseApi.mutate()
|
||||
.baseUrl("https://api.openai.com")
|
||||
.apiKey("your-api-key-for-gpt4")
|
||||
.build();
|
||||
OpenAiChatModel gpt4Model = this.baseModel.mutate()
|
||||
.openAiApi(gpt4Api)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
|
||||
.build();
|
||||
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();
|
||||
|
||||
// Mutate for Llama
|
||||
OpenAiApi llamaApi = baseApi.mutate()
|
||||
OpenAiApi llamaApi = this.baseApi.mutate()
|
||||
.baseUrl("https://your-custom-endpoint.com")
|
||||
.apiKey("your-api-key-for-llama")
|
||||
.build();
|
||||
OpenAiChatModel llamaModel = baseModel.mutate()
|
||||
OpenAiChatModel llamaModel = this.baseModel.mutate()
|
||||
.openAiApi(llamaApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
|
||||
.build();
|
||||
@@ -72,29 +75,29 @@ class OpenAiChatModelMutateTests {
|
||||
|
||||
@Test
|
||||
void testCloneCreatesDeepCopy() {
|
||||
OpenAiChatModel clone = baseModel.clone();
|
||||
assertThat(clone).isNotSameAs(baseModel);
|
||||
assertThat(clone.toString()).isEqualTo(baseModel.toString());
|
||||
OpenAiChatModel clone = this.baseModel.clone();
|
||||
assertThat(clone).isNotSameAs(this.baseModel);
|
||||
assertThat(clone.toString()).isEqualTo(this.baseModel.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
void mutateDoesNotAffectOriginal() {
|
||||
OpenAiChatModel mutated = baseModel.mutate()
|
||||
OpenAiChatModel mutated = this.baseModel.mutate()
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
|
||||
.build();
|
||||
assertThat(mutated).isNotSameAs(baseModel);
|
||||
assertThat(mutated).isNotSameAs(this.baseModel);
|
||||
assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4");
|
||||
assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
|
||||
assertThat(this.baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
|
||||
}
|
||||
|
||||
@Test
|
||||
void mutateHeadersCreatesDistinctHeaders() {
|
||||
OpenAiApi mutatedApi = baseApi.mutate()
|
||||
OpenAiApi mutatedApi = this.baseApi.mutate()
|
||||
.headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value"))))
|
||||
.build();
|
||||
|
||||
assertThat(mutatedApi.getHeaders()).containsKey("X-Test");
|
||||
assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test");
|
||||
assertThat(this.baseApi.getHeaders()).doesNotContainKey("X-Test");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -108,7 +111,9 @@ class OpenAiChatModelMutateTests {
|
||||
|
||||
@Test
|
||||
void multipleSequentialMutationsProduceDistinctInstances() {
|
||||
OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build();
|
||||
OpenAiChatModel m1 = this.baseModel.mutate()
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("m1").build())
|
||||
.build();
|
||||
OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build();
|
||||
OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build();
|
||||
assertThat(m1).isNotSameAs(m2);
|
||||
@@ -120,8 +125,8 @@ class OpenAiChatModelMutateTests {
|
||||
|
||||
@Test
|
||||
void mutateAndCloneAreEquivalent() {
|
||||
OpenAiChatModel mutated = baseModel.mutate().build();
|
||||
OpenAiChatModel cloned = baseModel.clone();
|
||||
OpenAiChatModel mutated = this.baseModel.mutate().build();
|
||||
OpenAiChatModel cloned = this.baseModel.clone();
|
||||
assertThat(mutated.toString()).isEqualTo(cloned.toString());
|
||||
assertThat(mutated).isNotSameAs(cloned);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.net.MalformedURLException;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
@@ -643,14 +643,14 @@ public class OpenAiChatModelIT extends AbstractIT {
|
||||
|
||||
UserMessage userMessage1 = new UserMessage("My name is James Bond");
|
||||
memory.add(conversationId, userMessage1);
|
||||
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response1).isNotNull();
|
||||
memory.add(conversationId, response1.getResult().getOutput());
|
||||
|
||||
UserMessage userMessage2 = new UserMessage("What is my name?");
|
||||
memory.add(conversationId, userMessage2);
|
||||
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
|
||||
|
||||
assertThat(response2).isNotNull();
|
||||
memory.add(conversationId, response2.getResult().getOutput());
|
||||
@@ -675,7 +675,7 @@ public class OpenAiChatModelIT extends AbstractIT {
|
||||
chatMemory.add(conversationId, prompt.getInstructions());
|
||||
|
||||
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
ChatResponse chatResponse = chatModel.call(promptWithMemory);
|
||||
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
|
||||
while (chatResponse.hasToolCalls()) {
|
||||
@@ -684,7 +684,7 @@ public class OpenAiChatModelIT extends AbstractIT {
|
||||
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
|
||||
.get(toolExecutionResult.conversationHistory().size() - 1));
|
||||
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
|
||||
chatResponse = chatModel.call(promptWithMemory);
|
||||
chatResponse = this.chatModel.call(promptWithMemory);
|
||||
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
|
||||
}
|
||||
|
||||
@@ -694,21 +694,12 @@ public class OpenAiChatModelIT extends AbstractIT {
|
||||
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
|
||||
chatMemory.add(conversationId, newUserMessage);
|
||||
|
||||
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
|
||||
|
||||
assertThat(newResponse).isNotNull();
|
||||
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");
|
||||
}
|
||||
|
||||
static class MathTools {
|
||||
|
||||
@Tool(description = "Multiply the two numbers")
|
||||
double multiply(double a, double b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void webSearchAnnotationsTest() {
|
||||
UserMessage userMessage = new UserMessage("What is the latest news on the Mars rover?");
|
||||
@@ -779,4 +770,13 @@ public class OpenAiChatModelIT extends AbstractIT {
|
||||
|
||||
}
|
||||
|
||||
static class MathTools {
|
||||
|
||||
@Tool(description = "Multiply the two numbers")
|
||||
double multiply(double a, double b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -28,10 +28,10 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.openai.chat.client;
|
||||
|
||||
import java.util.List;
|
||||
@@ -41,7 +57,7 @@ class OpenAiChatClientMemoryAdvisorReproIT {
|
||||
.build();
|
||||
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act: call should succeed without exception (issue #2339 is fixed)
|
||||
chatClient.prompt(prompt).call().chatResponse(); // Should not throw
|
||||
|
||||
@@ -57,7 +57,7 @@ public class ReReadingAdvisor implements BaseAdvisor {
|
||||
@Override
|
||||
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
|
||||
String augmentedUserText = PromptTemplate.builder()
|
||||
.template(re2AdviseTemplate)
|
||||
.template(this.re2AdviseTemplate)
|
||||
.variables(Map.of("re2_input_query", chatClientRequest.prompt().getUserMessage().getText()))
|
||||
.build()
|
||||
.render();
|
||||
|
||||
@@ -81,7 +81,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Create a prompt with multiple user messages
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -131,7 +131,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
// Create advisor with the conversation ID
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act - Create a list of messages for the prompt
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -193,7 +193,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
// Create advisor without a default conversation ID
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
String question = "What is the capital of Germany?";
|
||||
|
||||
@@ -231,7 +231,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
// Create advisor without a default conversation ID
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act - First conversation
|
||||
String answer1 = chatClient.prompt()
|
||||
@@ -316,7 +316,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
// Create advisor without a default conversation ID
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act - Send a question to a non-existent conversation
|
||||
String question = "Do you remember our previous conversation?";
|
||||
@@ -375,7 +375,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
|
||||
|
||||
var advisor = createAdvisor(chatMemory);
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
List<String> responseList = new ArrayList<>();
|
||||
for (String message : List.of("My name is Charlie.", "I am 30 years old.", "I live in London.")) {
|
||||
|
||||
@@ -91,7 +91,7 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
|
||||
.conversationId(conversationId)
|
||||
.build();
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Create a prompt with multiple user messages
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -151,7 +151,7 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
|
||||
.conversationId(conversationId)
|
||||
.build();
|
||||
|
||||
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
|
||||
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
|
||||
|
||||
// Act - Use streaming API
|
||||
String userInput = "Tell me a short joke about programming";
|
||||
|
||||
@@ -50,25 +50,25 @@ class MultiOpenAiClientIT {
|
||||
@Test
|
||||
void multiClientFlow() {
|
||||
// Derive a new OpenAiApi for Groq (Llama3)
|
||||
OpenAiApi groqApi = baseOpenAiApi.mutate()
|
||||
OpenAiApi groqApi = this.baseOpenAiApi.mutate()
|
||||
.baseUrl("https://api.groq.com/openai")
|
||||
.apiKey(System.getenv("GROQ_API_KEY"))
|
||||
.build();
|
||||
|
||||
// Derive a new OpenAiApi for OpenAI GPT-4
|
||||
OpenAiApi gpt4Api = baseOpenAiApi.mutate()
|
||||
OpenAiApi gpt4Api = this.baseOpenAiApi.mutate()
|
||||
.baseUrl("https://api.openai.com")
|
||||
.apiKey(System.getenv("OPENAI_API_KEY"))
|
||||
.build();
|
||||
|
||||
// Derive a new OpenAiChatModel for Groq
|
||||
OpenAiChatModel groqModel = baseChatModel.mutate()
|
||||
OpenAiChatModel groqModel = this.baseChatModel.mutate()
|
||||
.openAiApi(groqApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build())
|
||||
.build();
|
||||
|
||||
// Derive a new OpenAiChatModel for GPT-4
|
||||
OpenAiChatModel gpt4Model = baseChatModel.mutate()
|
||||
OpenAiChatModel gpt4Model = this.baseChatModel.mutate()
|
||||
.openAiApi(gpt4Api)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
|
||||
.build();
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user