Fixing miscellaneous checkstyle errors

Enabling checkstyle by default in the project build

Signed-off-by: Soby Chacko <soby.chacko@broadcom.com>
This commit is contained in:
Soby Chacko
2025-05-14 20:34:27 -04:00
committed by Ilayaperumal Gopinathan
parent 31feb4319b
commit 368be3a04f
218 changed files with 1281 additions and 1080 deletions

View File

@@ -21,14 +21,14 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;

View File

@@ -21,7 +21,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.util.Assert;
import reactor.core.scheduler.Scheduler;
import org.springframework.ai.chat.client.ChatClientRequest;
@@ -38,6 +37,7 @@ import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.util.Assert;
/**
* Memory is retrieved from a VectorStore added into the prompt's system text.
@@ -50,7 +50,7 @@ import org.springframework.ai.vectorstore.VectorStore;
* @author Mark Pollack
* @since 1.0.0
*/
public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
public final class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
public static final String TOP_K = "chat_memory_vector_store_top_k";
@@ -104,7 +104,7 @@ public class VectorStoreChatMemoryAdvisor implements BaseChatMemoryAdvisor {
@Override
public int getOrder() {
return order;
return this.order;
}
@Override

View File

@@ -1,7 +1,24 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.vectorstore;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.ai.vectorstore.VectorStore;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

View File

@@ -109,33 +109,6 @@ public class McpClientCommonProperties {
*/
private Toolcallback toolcallback = new Toolcallback();
/**
* Represents a callback configuration for tools.
* <p>
* This record is used to encapsulate the configuration for enabling or disabling tool
* callbacks in the MCP client.
*
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
* true, the tool callback is active; otherwise, it is disabled.
*/
public static class Toolcallback {
/**
* A boolean flag indicating whether the tool callback is enabled. If true, the
* tool callback is active; otherwise, it is disabled.
*/
private boolean enabled = true;
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
public boolean isEnabled() {
return this.enabled;
}
}
public boolean isEnabled() {
return this.enabled;
}
@@ -193,11 +166,38 @@ public class McpClientCommonProperties {
}
public Toolcallback getToolcallback() {
return toolcallback;
return this.toolcallback;
}
public void setToolcallback(Toolcallback toolcallback) {
this.toolcallback = toolcallback;
}
/**
* Represents a callback configuration for tools.
* <p>
* This record is used to encapsulate the configuration for enabling or disabling tool
* callbacks in the MCP client.
*
* @param enabled A boolean flag indicating whether the tool callback is enabled. If
* true, the tool callback is active; otherwise, it is disabled.
*/
public static class Toolcallback {
/**
* A boolean flag indicating whether the tool callback is enabled. If true, the
* tool callback is active; otherwise, it is disabled.
*/
private boolean enabled = true;
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
public boolean isEnabled() {
return this.enabled;
}
}
}

View File

@@ -16,13 +16,14 @@
package org.springframework.ai.mcp.client.autoconfigure.properties;
import java.time.Duration;
import org.junit.jupiter.api.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;
import static org.assertj.core.api.Assertions.assertThat;
/**

View File

@@ -16,13 +16,14 @@
package org.springframework.ai.mcp.client.autoconfigure.properties;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
/**

View File

@@ -351,4 +351,4 @@ public class McpServerAutoConfiguration {
return serverBuilder.build();
}
}
}

View File

@@ -13,11 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.mcp.server.autoconfigure;
import static org.assertj.core.api.Assertions.assertThat;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.jackson.JacksonAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -25,8 +27,7 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.reactive.function.server.RouterFunction;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
import static org.assertj.core.api.Assertions.assertThat;
class McpWebFluxServerAutoConfigurationTests {
@@ -36,7 +37,7 @@ class McpWebFluxServerAutoConfigurationTests {
@Test
void shouldConfigureWebFluxTransportWithCustomObjectMapper() {
this.contextRunner.run((context) -> {
this.contextRunner.run(context -> {
assertThat(context).hasSingleBean(WebFluxSseServerTransportProvider.class);
assertThat(context).hasSingleBean(RouterFunction.class);
assertThat(context).hasSingleBean(McpServerProperties.class);
@@ -48,6 +49,7 @@ class McpWebFluxServerAutoConfigurationTests {
.isEnabled(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)).isFalse();
// Test with a JSON payload containing unknown fields
// CHECKSTYLE:OFF
String jsonWithUnknownField = """
{
"tools": ["tool1", "tool2"],
@@ -55,6 +57,7 @@ class McpWebFluxServerAutoConfigurationTests {
"unknownField": "value"
}
""";
// CHECKSTYLE:ON
// This should not throw an exception
TestMessage message = objectMapper.readValue(jsonWithUnknownField, TestMessage.class);
@@ -75,7 +78,7 @@ class McpWebFluxServerAutoConfigurationTests {
private String name;
public String getName() {
return name;
return this.name;
}
public void setName(String name) {

View File

@@ -20,6 +20,7 @@ import io.micrometer.observation.ObservationRegistry;
import io.micrometer.tracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
@@ -30,7 +31,11 @@ import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

View File

@@ -19,6 +19,7 @@ package org.springframework.ai.model.chat.client.autoconfigure;
import io.micrometer.tracing.Tracer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationHandler;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;

View File

@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.cassandra.autoconfig
import com.datastax.oss.driver.api.core.CqlSession;
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.cassandra.CassandraChatMemoryRepositoryConfig;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.cassandra.CassandraAutoConfiguration;

View File

@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.jdbc.autoconfigure;
import javax.sql.DataSource;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepositoryDialect;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;

View File

@@ -59,7 +59,7 @@ public class JdbcChatMemoryRepositoryProperties {
}
public String getPlatform() {
return platform;
return this.platform;
}
public void setPlatform(String platform) {

View File

@@ -74,11 +74,11 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
// Debug: Print current schemas and tables
try {
List<String> schemas = jdbcTemplate.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA",
String.class);
List<String> schemas = this.jdbcTemplate
.queryForList("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA", String.class);
System.out.println("Available schemas: " + schemas);
List<String> tables = jdbcTemplate
List<String> tables = this.jdbcTemplate
.queryForList("SELECT TABLE_SCHEMA, TABLE_NAME FROM INFORMATION_SCHEMA.TABLES", String.class);
System.out.println("Available tables: " + tables);
}
@@ -89,22 +89,22 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
// Try a more direct approach with explicit SQL statements
try {
// Drop the table first if it exists to avoid any conflicts
jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
this.jdbcTemplate.execute("DROP TABLE SPRING_AI_CHAT_MEMORY IF EXISTS");
System.out.println("Dropped existing table if it existed");
// Create the table with a simplified schema
jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY (" + "conversation_id VARCHAR(36) NOT NULL, "
+ "content LONGVARCHAR NOT NULL, " + "type VARCHAR(10) NOT NULL, "
+ "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
this.jdbcTemplate.execute("CREATE TABLE SPRING_AI_CHAT_MEMORY ("
+ "conversation_id VARCHAR(36) NOT NULL, " + "content LONGVARCHAR NOT NULL, "
+ "type VARCHAR(10) NOT NULL, " + "timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL)");
System.out.println("Created table with simplified schema");
// Create index
jdbcTemplate.execute(
this.jdbcTemplate.execute(
"CREATE INDEX SPRING_AI_CHAT_MEMORY_IDX ON SPRING_AI_CHAT_MEMORY(conversation_id, timestamp DESC)");
System.out.println("Created index");
// Verify table was created
boolean tableExists = jdbcTemplate.queryForObject(
boolean tableExists = this.jdbcTemplate.queryForObject(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
Integer.class) > 0;
System.out.println("Table SPRING_AI_CHAT_MEMORY exists after creation: " + tableExists);
@@ -125,7 +125,7 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
@Test
public void useAutoConfiguredChatMemoryWithJdbc() {
// Check that the custom schema initializer is present
assertThat(context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();
assertThat(this.context.containsBean("jdbcChatMemoryScriptDatabaseInitializer")).isTrue();
// Debug: List all schema-hsqldb.sql resources on the classpath
try {
@@ -144,7 +144,7 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
// Verify the table exists by executing a direct query
try {
boolean tableExists = jdbcTemplate.queryForObject(
boolean tableExists = this.jdbcTemplate.queryForObject(
"SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'SPRING_AI_CHAT_MEMORY'",
Integer.class) > 0;
System.out.println("Table SPRING_AI_CHAT_MEMORY exists: " + tableExists);
@@ -157,10 +157,10 @@ public class JdbcChatMemoryRepositoryHsqldbAutoConfigurationIT {
}
// Now test the ChatMemory functionality
assertThat(context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
assertThat(context.getBean(JdbcChatMemoryRepository.class)).isNotNull();
assertThat(this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class)).isNotNull();
assertThat(this.context.getBean(JdbcChatMemoryRepository.class)).isNotNull();
var chatMemory = context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
var chatMemory = this.context.getBean(org.springframework.ai.chat.memory.ChatMemory.class);
var conversationId = java.util.UUID.randomUUID().toString();
var userMessage = new UserMessage("Message from the user");

View File

@@ -55,12 +55,14 @@ class JdbcChatMemoryRepositoryPostgresqlAutoConfigurationIT {
@Test
void jdbcChatMemoryScriptDatabaseInitializer_shouldNotRunSchemaInit() {
// CHECKSTYLE:OFF
this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=never")
.run(context -> {
assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer");
// Optionally, check that the schema is not initialized (could check table
// absence if needed)
});
// CHECKSTYLE:ON
}
@Test

View File

@@ -1,6 +1,19 @@
/*
* Integration test for SQL Server using Testcontainers, following the same structure as the PostgreSQL test.
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.model.chat.memory.repository.jdbc.autoconfigure;
import java.time.Duration;
@@ -8,6 +21,10 @@ import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.MSSQLServerContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
@@ -19,13 +36,12 @@ import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.testcontainers.containers.MSSQLServerContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import static org.assertj.core.api.Assertions.assertThat;
/*
* Integration test for SQL Server using Testcontainers, following the same structure as the PostgreSQL test.
*/
@Testcontainers
class JdbcChatMemoryRepositorySqlServerAutoConfigurationIT {
@@ -58,9 +74,7 @@ class JdbcChatMemoryRepositorySqlServerAutoConfigurationIT {
@Test
void jdbcChatMemoryScriptDatabaseInitializer_shouldNotRunSchemaInit() {
this.contextRunner.withPropertyValues("spring.ai.chat.memory.repository.jdbc.initialize-schema=never")
.run(context -> {
assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer");
});
.run(context -> assertThat(context).doesNotHaveBean("jdbcChatMemoryScriptDatabaseInitializer"));
}
@Test

View File

@@ -18,8 +18,8 @@ package org.springframework.ai.model.chat.memory.repository.neo4j.autoconfigure;
import org.neo4j.driver.Driver;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;

View File

@@ -23,14 +23,13 @@ import java.util.Map;
import java.util.UUID;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;

View File

@@ -17,6 +17,7 @@
package org.springframework.ai.model.chat.memory.autoconfigure;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository;
@@ -40,7 +41,7 @@ class ChatMemoryAutoConfigurationTests {
@Test
void defaultConfiguration() {
contextRunner.run(context -> {
this.contextRunner.run(context -> {
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
assertThat(context).hasSingleBean(ChatMemory.class);
});
@@ -48,7 +49,7 @@ class ChatMemoryAutoConfigurationTests {
@Test
void whenChatMemoryRepositoryExists() {
contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
this.contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
assertThat(context).hasBean("customChatMemoryRepository");
assertThat(context).doesNotHaveBean("chatMemoryRepository");
@@ -57,7 +58,7 @@ class ChatMemoryAutoConfigurationTests {
@Test
void whenChatMemoryExists() {
contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
this.contextRunner.withUserConfiguration(CustomChatMemoryRepositoryConfiguration.class).run(context -> {
assertThat(context).hasSingleBean(ChatMemoryRepository.class);
assertThat(context).hasBean("customChatMemoryRepository");
assertThat(context).doesNotHaveBean("chatMemoryRepository");
@@ -71,7 +72,7 @@ class ChatMemoryAutoConfigurationTests {
@Bean
ChatMemoryRepository customChatMemoryRepository() {
return customChatMemoryRepository;
return this.customChatMemoryRepository;
}
}
@@ -83,7 +84,7 @@ class ChatMemoryAutoConfigurationTests {
@Bean
ChatMemory customChatMemory() {
return customChatMemory;
return this.customChatMemory;
}
}

View File

@@ -16,10 +16,13 @@
package org.springframework.ai.model.chat.observation.autoconfigure;
import java.util.List;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.tracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.model.ChatModel;
@@ -33,13 +36,15 @@ import org.springframework.ai.model.observation.ErrorLoggingObservationHandler;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
/**
* Auto-configuration for Spring AI chat model observations.
*

View File

@@ -16,10 +16,13 @@
package org.springframework.ai.model.chat.observation.autoconfigure;
import java.util.List;
import io.micrometer.core.instrument.composite.CompositeMeterRegistry;
import io.micrometer.tracing.Tracer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.observation.ChatModelCompletionObservationHandler;
import org.springframework.ai.chat.observation.ChatModelMeterObservationHandler;
@@ -35,8 +38,6 @@ import org.springframework.boot.test.system.OutputCaptureExtension;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;

View File

@@ -19,12 +19,17 @@ package org.springframework.ai.model.image.observation.autoconfigure;
import io.micrometer.tracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelPromptContentObservationHandler;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

View File

@@ -19,6 +19,7 @@ package org.springframework.ai.model.image.observation.autoconfigure;
import io.micrometer.tracing.Tracer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.ai.image.observation.ImageModelObservationContext;
import org.springframework.ai.image.observation.ImageModelPromptContentObservationHandler;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;

View File

@@ -17,7 +17,6 @@
package org.springframework.ai.model.bedrock.titan.autoconfigure;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.observation.ObservationRegistry;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;

View File

@@ -17,10 +17,11 @@
package org.springframework.ai.model.deepseek.autoconfigure;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.model.SimpleApiKey;
import org.springframework.ai.model.SpringAIModelProperties;
import org.springframework.ai.model.SpringAIModels;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;

View File

@@ -71,7 +71,7 @@ public class DeepSeekChatProperties extends DeepSeekParentProperties {
}
public String getCompletionsPath() {
return completionsPath;
return this.completionsPath;
}
public void setCompletionsPath(String completionsPath) {
@@ -79,7 +79,7 @@ public class DeepSeekChatProperties extends DeepSeekParentProperties {
}
public String getBetaPrefixPath() {
return betaPrefixPath;
return this.betaPrefixPath;
}
public void setBetaPrefixPath(String betaPrefixPath) {

View File

@@ -16,10 +16,15 @@
package org.springframework.ai.model.deepseek.autoconfigure;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
@@ -28,10 +33,6 @@ import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import reactor.core.publisher.Flux;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;

View File

@@ -17,6 +17,7 @@
package org.springframework.ai.model.deepseek.autoconfigure;
import org.junit.jupiter.api.Test;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;

View File

@@ -16,10 +16,16 @@
package org.springframework.ai.model.deepseek.autoconfigure.tool;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
@@ -36,11 +42,6 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfigura
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;

View File

@@ -16,10 +16,15 @@
package org.springframework.ai.model.deepseek.autoconfigure.tool;
import java.util.List;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
@@ -31,10 +36,6 @@ import org.springframework.ai.model.deepseek.autoconfigure.DeepSeekChatAutoConfi
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;

View File

@@ -16,10 +16,16 @@
package org.springframework.ai.model.deepseek.autoconfigure.tool;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
@@ -34,11 +40,6 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;

View File

@@ -16,14 +16,14 @@
package org.springframework.ai.model.deepseek.autoconfigure.tool;
import java.util.function.Function;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import java.util.function.Function;
/**
* Mock 3rd party weather service.
*

View File

@@ -54,7 +54,7 @@ public class OpenAiImageProperties extends OpenAiParentProperties {
}
public String getImagesPath() {
return imagesPath;
return this.imagesPath;
}
public void setImagesPath(String imagesPath) {

View File

@@ -16,9 +16,13 @@
package org.springframework.ai.model.tool.autoconfigure;
import java.util.ArrayList;
import java.util.List;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.tool.ToolCallback;
@@ -40,9 +44,6 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.support.GenericApplicationContext;
import java.util.ArrayList;
import java.util.List;
/**
* Auto-configuration for common tool calling features of {@link ChatModel}.
*

View File

@@ -39,7 +39,7 @@ public class ToolCallingProperties {
private boolean includeContent = false;
public boolean isIncludeContent() {
return includeContent;
return this.includeContent;
}
public void setIncludeContent(boolean includeContent) {

View File

@@ -116,9 +116,7 @@ class ToolCallingAutoConfigurationTests {
void observationFilterDefault() {
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
.withUserConfiguration(Config.class)
.run(context -> {
assertThat(context).doesNotHaveBean(ToolCallingContentObservationFilter.class);
});
.run(context -> assertThat(context).doesNotHaveBean(ToolCallingContentObservationFilter.class));
}
@Test
@@ -126,9 +124,7 @@ class ToolCallingAutoConfigurationTests {
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class))
.withPropertyValues("spring.ai.tools.observations.include-content=true")
.withUserConfiguration(Config.class)
.run(context -> {
assertThat(context).hasSingleBean(ToolCallingContentObservationFilter.class);
});
.run(context -> assertThat(context).hasSingleBean(ToolCallingContentObservationFilter.class));
}
static class WeatherService {

View File

@@ -64,7 +64,7 @@ public class CosmosDBVectorStoreAutoConfiguration {
}
CosmosClientBuilder builder = new CosmosClientBuilder().endpoint(properties.getEndpoint())
.userAgentSuffix(agentSuffix);
.userAgentSuffix(this.agentSuffix);
if (properties.getKey() == null || properties.getKey().isEmpty()) {
builder.credential(new DefaultAzureCredentialBuilder().build());

View File

@@ -16,12 +16,12 @@
package org.springframework.ai.vectorstore.cosmosdb.autoconfigure;
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import java.util.Arrays;
import java.util.List;
import org.springframework.ai.vectorstore.properties.CommonVectorStoreProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
/**
* Configuration properties for CosmosDB Vector Store.
*

View File

@@ -39,7 +39,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties {
private String collectionName = ChromaApiConstants.DEFAULT_COLLECTION_NAME;
public String getTenantName() {
return tenantName;
return this.tenantName;
}
public void setTenantName(String tenantName) {
@@ -47,7 +47,7 @@ public class ChromaVectorStoreProperties extends CommonVectorStoreProperties {
}
public String getDatabaseName() {
return databaseName;
return this.databaseName;
}
public void setDatabaseName(String databaseName) {

View File

@@ -19,12 +19,17 @@ package org.springframework.ai.vectorstore.observation.autoconfigure;
import io.micrometer.tracing.Tracer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.*;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

View File

@@ -19,6 +19,7 @@ package org.springframework.ai.vectorstore.observation.autoconfigure;
import io.micrometer.tracing.Tracer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.ai.observation.TracingAwareLoggingObservationHandler;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreQueryResponseObservationHandler;

View File

@@ -52,7 +52,6 @@ import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.ssl.SslBundles;

View File

@@ -98,7 +98,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
}
public String getSslBundle() {
return sslBundle;
return this.sslBundle;
}
public void setSslBundle(String sslBundle) {
@@ -106,7 +106,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
}
public Duration getConnectionTimeout() {
return connectionTimeout;
return this.connectionTimeout;
}
public void setConnectionTimeout(Duration connectionTimeout) {
@@ -114,7 +114,7 @@ public class OpenSearchVectorStoreProperties extends CommonVectorStoreProperties
}
public Duration getReadTimeout() {
return readTimeout;
return this.readTimeout;
}
public void setReadTimeout(Duration readTimeout) {

View File

@@ -47,7 +47,7 @@ import org.springframework.util.Assert;
* @author Mick Semb Wever
* @since 1.0.0
*/
public class CassandraChatMemoryRepository implements ChatMemoryRepository {
public final class CassandraChatMemoryRepository implements ChatMemoryRepository {
public static final String CONVERSATION_TS = CassandraChatMemoryRepository.class.getSimpleName()
+ "_message_timestamp";
@@ -125,7 +125,7 @@ public class CassandraChatMemoryRepository implements ChatMemoryRepository {
Instant instant = Instant.now();
List<Object> primaryKeys = this.conf.primaryKeyTranslator.apply(conversationId);
BoundStatementBuilder builder = addStmt.boundStatementBuilder();
BoundStatementBuilder builder = this.addStmt.boundStatementBuilder();
for (int k = 0; k < primaryKeys.size(); ++k) {
CassandraChatMemoryRepositoryConfig.SchemaColumn keyColumn = this.conf.getPrimaryKeyColumn(k);

View File

@@ -34,8 +34,6 @@ import com.datastax.oss.driver.api.core.type.UserDefinedType;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumn;
import com.datastax.oss.driver.api.querybuilder.schema.AlterTableAddColumnEnd;
import com.datastax.oss.driver.api.querybuilder.schema.CreateTable;
import com.datastax.oss.driver.api.querybuilder.schema.CreateTableStart;
import com.datastax.oss.driver.api.querybuilder.schema.CreateTableWithOptions;
@@ -140,13 +138,13 @@ public final class CassandraChatMemoryRepositoryConfig {
Preconditions.checkState(this.session.getMetadata()
.getKeyspace(this.schema.keyspace())
.get()
.getUserDefinedType(messageUDT)
.getUserDefinedType(this.messageUDT)
.isPresent(), "table %s does not exist");
UserDefinedType udt = this.session.getMetadata()
.getKeyspace(this.schema.keyspace())
.get()
.getUserDefinedType(messageUDT)
.getUserDefinedType(this.messageUDT)
.get();
Preconditions.checkState(udt.contains(this.messageUdtTimestampColumn), "field %s does not exist",
@@ -186,7 +184,7 @@ public final class CassandraChatMemoryRepositoryConfig {
String lastClusteringColumn = this.schema.clusteringKeys.get(this.schema.clusteringKeys.size() - 1).name();
CreateTableWithOptions createTableWithOptions = createTable
.withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
.withColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(this.messageUDT, true)))
.withClusteringOrder(lastClusteringColumn, ClusteringOrder.DESC)
// TODO replace w/ SchemaBuilder.unifiedCompactionStrategy() when
// available
@@ -201,11 +199,11 @@ public final class CassandraChatMemoryRepositoryConfig {
private void ensureMessageTypeExist() {
SimpleStatement stmt = SchemaBuilder.createType(messageUDT)
SimpleStatement stmt = SchemaBuilder.createType(this.messageUDT)
.ifNotExists()
.withField(messageUdtTimestampColumn, DataTypes.TIMESTAMP)
.withField(messageUdtTypeColumn, DataTypes.TEXT)
.withField(messageUdtContentColumn, DataTypes.TEXT)
.withField(this.messageUdtTimestampColumn, DataTypes.TIMESTAMP)
.withField(this.messageUdtTypeColumn, DataTypes.TEXT)
.withField(this.messageUdtContentColumn, DataTypes.TEXT)
.build();
this.session.execute(stmt.setKeyspace(this.schema.keyspace));
@@ -222,7 +220,7 @@ public final class CassandraChatMemoryRepositoryConfig {
if (tableMetadata.getColumn(this.messagesColumn).isEmpty()) {
SimpleStatement stmt = SchemaBuilder.alterTable(this.schema.keyspace(), this.schema.table())
.addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(messageUDT, true)))
.addColumn(this.messagesColumn, DataTypes.frozenListOf(SchemaBuilder.udt(this.messageUDT, true)))
.build();
logger.debug("Executing {}", stmt.getQuery());

View File

@@ -26,7 +26,6 @@ import com.datastax.oss.driver.api.core.cql.ResultSet;
import com.datastax.oss.driver.api.core.data.UdtValue;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.testcontainers.cassandra.CassandraContainer;

View File

@@ -16,7 +16,10 @@
package org.springframework.ai.chat.memory.repository.jdbc;
import java.sql.*;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
@@ -24,6 +27,9 @@ import java.util.concurrent.atomic.AtomicLong;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
@@ -37,15 +43,8 @@ import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import org.springframework.lang.Nullable;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionException;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.AbstractPlatformTransactionManager;
import org.springframework.transaction.support.DefaultTransactionStatus;
import org.springframework.transaction.support.TransactionTemplate;
import org.springframework.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An implementation of {@link ChatMemoryRepository} for JDBC.
@@ -56,7 +55,7 @@ import org.slf4j.LoggerFactory;
* @author Mark Pollack
* @since 1.0.0
*/
public class JdbcChatMemoryRepository implements ChatMemoryRepository {
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
private final JdbcTemplate jdbcTemplate;
@@ -78,7 +77,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
@Override
public List<String> findConversationIds() {
List<String> conversationIds = this.jdbcTemplate.query(dialect.getSelectConversationIdsSql(), rs -> {
List<String> conversationIds = this.jdbcTemplate.query(this.dialect.getSelectConversationIdsSql(), rs -> {
var ids = new ArrayList<String>();
while (rs.next()) {
ids.add(rs.getString(1));
@@ -91,7 +90,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
@Override
public List<Message> findByConversationId(String conversationId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
return this.jdbcTemplate.query(dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
return this.jdbcTemplate.query(this.dialect.getSelectMessagesSql(), new MessageRowMapper(), conversationId);
}
@Override
@@ -100,9 +99,9 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
Assert.notNull(messages, "messages cannot be null");
Assert.noNullElements(messages, "messages cannot contain null elements");
transactionTemplate.execute(status -> {
this.transactionTemplate.execute(status -> {
deleteByConversationId(conversationId);
jdbcTemplate.batchUpdate(dialect.getInsertMessageSql(),
this.jdbcTemplate.batchUpdate(this.dialect.getInsertMessageSql(),
new AddBatchPreparedStatement(conversationId, messages));
return null;
});
@@ -111,7 +110,11 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
@Override
public void deleteByConversationId(String conversationId) {
Assert.hasText(conversationId, "conversationId cannot be null or empty");
this.jdbcTemplate.update(dialect.getDeleteMessagesSql(), conversationId);
this.jdbcTemplate.update(this.dialect.getDeleteMessagesSql(), conversationId);
}
public static Builder builder() {
return new Builder();
}
private record AddBatchPreparedStatement(String conversationId, List<Message> messages,
@@ -128,7 +131,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
ps.setTimestamp(4, new Timestamp(instantSeq.getAndIncrement()));
ps.setTimestamp(4, new Timestamp(this.instantSeq.getAndIncrement()));
}
@Override
@@ -158,11 +161,7 @@ public class JdbcChatMemoryRepository implements ChatMemoryRepository {
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
public static final class Builder {
private JdbcTemplate jdbcTemplate;

View File

@@ -55,16 +55,21 @@ public interface JdbcChatMemoryRepositoryDialect {
// Simple detection (could be improved)
try {
String url = dataSource.getConnection().getMetaData().getURL().toLowerCase();
if (url.contains("postgresql"))
if (url.contains("postgresql")) {
return new PostgresChatMemoryRepositoryDialect();
if (url.contains("mysql"))
}
if (url.contains("mysql")) {
return new MysqlChatMemoryRepositoryDialect();
if (url.contains("mariadb"))
}
if (url.contains("mariadb")) {
return new MysqlChatMemoryRepositoryDialect();
if (url.contains("sqlserver"))
}
if (url.contains("sqlserver")) {
return new SqlServerChatMemoryRepositoryDialect();
if (url.contains("hsqldb"))
}
if (url.contains("hsqldb")) {
return new HsqldbChatMemoryRepositoryDialect();
}
// Add more as needed
}
catch (Exception ignored) {

View File

@@ -16,31 +16,30 @@
package org.springframework.ai.chat.memory.repository.jdbc;
import java.sql.Timestamp;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.jdbc.core.JdbcTemplate;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import static org.assertj.core.api.Assertions.assertThat;
/**
@@ -58,7 +57,7 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
assertThat(this.chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
}
@ParameterizedTest
@@ -72,13 +71,14 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
case TOOL -> throw new IllegalArgumentException("TOOL message type not supported in this test");
};
chatMemoryRepository.saveAll(conversationId, List.of(message));
this.chatMemoryRepository.saveAll(conversationId, List.of(message));
// Use dialect to get the appropriate SQL query
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
.from(this.jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
var result = jdbcTemplate.queryForMap(selectSql, conversationId);
var result = this.jdbcTemplate.queryForMap(selectSql, conversationId);
assertThat(result.size()).isEqualTo(4);
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
@@ -94,13 +94,14 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
this.chatMemoryRepository.saveAll(conversationId, messages);
// Use dialect to get the appropriate SQL query
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource());
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
.from(this.jdbcTemplate.getDataSource());
String selectSql = dialect.getSelectMessagesSql()
.replace("content, type", "conversation_id, content, type, timestamp");
var results = jdbcTemplate.queryForList(selectSql, conversationId);
var results = this.jdbcTemplate.queryForList(selectSql, conversationId);
assertThat(results).hasSize(messages.size());
@@ -114,12 +115,12 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class);
}
var count = chatMemoryRepository.findByConversationId(conversationId).size();
var count = this.chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(messages.size());
chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
this.chatMemoryRepository.saveAll(conversationId, List.of(new UserMessage("Hello")));
count = chatMemoryRepository.findByConversationId(conversationId).size();
count = this.chatMemoryRepository.findByConversationId(conversationId).size();
assertThat(count).isEqualTo(1);
}
@@ -131,9 +132,9 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
this.chatMemoryRepository.saveAll(conversationId, messages);
var results = chatMemoryRepository.findByConversationId(conversationId);
var results = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(results.size()).isEqualTo(messages.size());
assertThat(results).isEqualTo(messages);
@@ -146,12 +147,12 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
new UserMessage("Message from user - " + conversationId),
new SystemMessage("Message from system - " + conversationId));
chatMemoryRepository.saveAll(conversationId, messages);
this.chatMemoryRepository.saveAll(conversationId, messages);
chatMemoryRepository.deleteByConversationId(conversationId);
this.chatMemoryRepository.deleteByConversationId(conversationId);
var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?",
Integer.class, conversationId);
var count = this.jdbcTemplate.queryForObject(
"SELECT COUNT(*) FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ?", Integer.class, conversationId);
assertThat(count).isZero();
}
@@ -160,8 +161,8 @@ public abstract class AbstractJdbcChatMemoryRepositoryIT {
void testMessageOrder() {
// Create a repository using the from method to detect the dialect
JdbcChatMemoryRepository repository = JdbcChatMemoryRepository.builder()
.jdbcTemplate(jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(jdbcTemplate.getDataSource()))
.jdbcTemplate(this.jdbcTemplate)
.dialect(JdbcChatMemoryRepositoryDialect.from(this.jdbcTemplate.getDataSource()))
.build();
var conversationId = UUID.randomUUID().toString();

View File

@@ -19,6 +19,7 @@ package org.springframework.ai.chat.memory.repository.jdbc;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.chat.memory.repository.jdbc;
import java.util.List;
import java.util.UUID;
import javax.sql.DataSource;
import org.junit.jupiter.api.Test;
@@ -53,7 +54,7 @@ class JdbcChatMemoryRepositoryPostgresqlIT extends AbstractJdbcChatMemoryReposit
void repositoryWithExplicitTransactionManager() {
// Get the repository with explicit transaction manager
ChatMemoryRepository repositoryWithTxManager = TestConfiguration
.chatMemoryRepositoryWithTransactionManager(jdbcTemplate, jdbcTemplate.getDataSource());
.chatMemoryRepositoryWithTransactionManager(this.jdbcTemplate, this.jdbcTemplate.getDataSource());
var conversationId = UUID.randomUUID().toString();
var messages = List.<Message>of(new AssistantMessage("Message with transaction manager - " + conversationId),

View File

@@ -1,18 +1,45 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.memory.repository.neo4j;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.neo4j.driver.Session;
import org.neo4j.driver.Transaction;
import org.neo4j.driver.TransactionContext;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
import org.springframework.util.MimeType;
import java.net.URI;
import java.util.*;
import java.util.stream.Collectors;
/**
* An implementation of {@link ChatMemoryRepository} for Neo4J
*
@@ -31,9 +58,9 @@ public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {
@Override
public List<String> findConversationIds() {
return config.getDriver()
return this.config.getDriver()
.executableQuery("MATCH (conversation:$($sessionLabel)) RETURN conversation.id")
.withParameters(Map.of("sessionLabel", config.getSessionLabel()))
.withParameters(Map.of("sessionLabel", this.config.getSessionLabel()))
.execute(Collectors.mapping(r -> r.get("conversation.id").asString(), Collectors.toList()));
}

View File

@@ -1,12 +1,28 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.memory.repository.neo4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Session;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
@@ -28,8 +44,9 @@ class Neo4JChatMemoryRepositoryConfigIT {
@AfterAll
static void closeDriver() {
if (driver != null)
if (driver != null) {
driver.close();
}
}
@Test
@@ -44,10 +61,12 @@ class Neo4JChatMemoryRepositoryConfigIT {
while (result.hasNext()) {
var record = result.next();
String name = record.get("name").asString();
if ("session_conversation_id_index".equals(name))
if ("session_conversation_id_index".equals(name)) {
sessionIndexFound = true;
if ("message_index_index".equals(name))
}
if ("message_index_index".equals(name)) {
messageIndexFound = true;
}
}
// Then
assertThat(sessionIndexFound).isTrue();

View File

@@ -16,6 +16,14 @@
package org.springframework.ai.chat.memory.repository.neo4j;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -24,6 +32,11 @@ import org.junit.jupiter.params.provider.CsvSource;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
@@ -34,18 +47,6 @@ import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.content.Media;
import org.springframework.util.MimeType;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
@@ -74,24 +75,24 @@ class Neo4jChatMemoryRepositoryIT {
@BeforeEach
void setUp() {
driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
config = Neo4jChatMemoryRepositoryConfig.builder().withDriver(driver).build();
chatMemoryRepository = new Neo4jChatMemoryRepository(config);
this.driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
this.config = Neo4jChatMemoryRepositoryConfig.builder().withDriver(this.driver).build();
this.chatMemoryRepository = new Neo4jChatMemoryRepository(this.config);
}
@AfterEach
void tearDown() {
// Clean up all data after each test
try (Session session = driver.session()) {
try (Session session = this.driver.session()) {
session.run("MATCH (n) DETACH DELETE n");
}
driver.close();
this.driver.close();
}
@Test
void correctChatMemoryRepositoryInstance() {
assertThat(chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
assertThat(chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
assertThat(this.chatMemoryRepository).isInstanceOf(ChatMemoryRepository.class);
assertThat(this.chatMemoryRepository).isInstanceOf(Neo4jChatMemoryRepository.class);
}
@ParameterizedTest
@@ -101,8 +102,8 @@ class Neo4jChatMemoryRepositoryIT {
var conversationId = UUID.randomUUID().toString();
Message message = createMessageByType(content + " - " + conversationId, messageType);
chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(message));
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
@@ -114,10 +115,10 @@ class Neo4jChatMemoryRepositoryIT {
}
// Verify directly in the database
try (Session session = driver.session()) {
try (Session session = this.driver.session()) {
var result = session.run(
"MATCH (s:%s {id:$conversationId})-[:HAS_MESSAGE]->(m:%s) RETURN count(m) as count"
.formatted(config.getSessionLabel(), config.getMessageLabel()),
.formatted(this.config.getSessionLabel(), this.config.getMessageLabel()),
Map.of("conversationId", conversationId));
assertThat(result.single().get("count").asLong()).isEqualTo(1);
}
@@ -131,8 +132,8 @@ class Neo4jChatMemoryRepositoryIT {
new SystemMessage("Message from system - " + conversationId),
new ToolResponseMessage(List.of(new ToolResponse("id", "name", "responseData"))));
chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
this.chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(messages.size());
@@ -155,8 +156,8 @@ class Neo4jChatMemoryRepositoryIT {
messages.add(new UserMessage("Message " + i));
}
chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
this.chatMemoryRepository.saveAll(conversationId, messages);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(messages.size());
@@ -173,11 +174,14 @@ class Neo4jChatMemoryRepositoryIT {
var conversationId2 = UUID.randomUUID().toString();
var conversationId3 = UUID.randomUUID().toString();
chatMemoryRepository.saveAll(conversationId1, List.<Message>of(new UserMessage("Message for conversation 1")));
chatMemoryRepository.saveAll(conversationId2, List.<Message>of(new UserMessage("Message for conversation 2")));
chatMemoryRepository.saveAll(conversationId3, List.<Message>of(new UserMessage("Message for conversation 3")));
this.chatMemoryRepository.saveAll(conversationId1,
List.<Message>of(new UserMessage("Message for conversation 1")));
this.chatMemoryRepository.saveAll(conversationId2,
List.<Message>of(new UserMessage("Message for conversation 2")));
this.chatMemoryRepository.saveAll(conversationId3,
List.<Message>of(new UserMessage("Message for conversation 3")));
List<String> conversationIds = chatMemoryRepository.findConversationIds();
List<String> conversationIds = this.chatMemoryRepository.findConversationIds();
assertThat(conversationIds).hasSize(3);
assertThat(conversationIds).contains(conversationId1, conversationId2, conversationId3);
@@ -189,22 +193,21 @@ class Neo4jChatMemoryRepositoryIT {
List<Message> messages = List.of(new AssistantMessage("Message from assistant"),
new UserMessage("Message from user"), new SystemMessage("Message from system"));
chatMemoryRepository.saveAll(conversationId, messages);
this.chatMemoryRepository.saveAll(conversationId, messages);
// Verify messages were saved
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
// Delete the conversation
chatMemoryRepository.deleteByConversationId(conversationId);
this.chatMemoryRepository.deleteByConversationId(conversationId);
// Verify messages were deleted
assertThat(chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).isEmpty();
// Verify directly in the database
try (Session session = driver.session()) {
var result = session.run(
"MATCH (s:%s {id:$conversationId}) RETURN count(s) as count".formatted(config.getSessionLabel()),
Map.of("conversationId", conversationId));
try (Session session = this.driver.session()) {
var result = session.run("MATCH (s:%s {id:$conversationId}) RETURN count(s) as count"
.formatted(this.config.getSessionLabel()), Map.of("conversationId", conversationId));
assertThat(result.single().get("count").asLong()).isZero();
}
}
@@ -216,17 +219,17 @@ class Neo4jChatMemoryRepositoryIT {
// Save initial messages
List<Message> initialMessages = List.of(new UserMessage("Initial message 1"),
new UserMessage("Initial message 2"), new UserMessage("Initial message 3"));
chatMemoryRepository.saveAll(conversationId, initialMessages);
this.chatMemoryRepository.saveAll(conversationId, initialMessages);
// Verify initial messages were saved
assertThat(chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
assertThat(this.chatMemoryRepository.findByConversationId(conversationId)).hasSize(3);
// Replace with new messages
List<Message> newMessages = List.of(new UserMessage("New message 1"), new UserMessage("New message 2"));
chatMemoryRepository.saveAll(conversationId, newMessages);
this.chatMemoryRepository.saveAll(conversationId, newMessages);
// Verify only new messages exist
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(2);
assertThat(retrievedMessages.get(0).getText()).isEqualTo("New message 1");
assertThat(retrievedMessages.get(1).getText()).isEqualTo("New message 2");
@@ -246,9 +249,9 @@ class Neo4jChatMemoryRepositoryIT {
UserMessage userMessageWithMedia = UserMessage.builder().text("Message with media").media(media).build();
chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(userMessageWithMedia));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
UserMessage retrievedMessage = (UserMessage) retrievedMessages.get(0);
@@ -264,9 +267,9 @@ class Neo4jChatMemoryRepositoryIT {
List.of(new AssistantMessage.ToolCall("id1", "type1", "name1", "arguments1"),
new AssistantMessage.ToolCall("id2", "type2", "name2", "arguments2")));
chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(assistantMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
AssistantMessage retrievedMessage = (AssistantMessage) retrievedMessages.get(0);
@@ -283,9 +286,9 @@ class Neo4jChatMemoryRepositoryIT {
.of(new ToolResponse("id1", "name1", "responseData1"), new ToolResponse("id2", "name2", "responseData2")),
Map.of("metadataKey", "metadataValue"));
chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
this.chatMemoryRepository.saveAll(conversationId, List.<Message>of(toolResponseMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
ToolResponseMessage retrievedMessage = (ToolResponseMessage) retrievedMessages.get(0);
@@ -305,8 +308,8 @@ class Neo4jChatMemoryRepositoryIT {
.metadata(customMetadata)
.build();
chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
this.chatMemoryRepository.saveAll(conversationId, List.of(systemMessage));
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(1);
Message retrievedMessage = retrievedMessages.get(0);
@@ -333,29 +336,29 @@ class Neo4jChatMemoryRepositoryIT {
// 1. Setup: Create a conversation with some initial messages
UserMessage initialMessage1 = new UserMessage("Initial message 1");
AssistantMessage initialMessage2 = new AssistantMessage("Initial response 1");
chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
this.chatMemoryRepository.saveAll(conversationId, List.of(initialMessage1, initialMessage2));
// Verify initial messages are there
List<Message> messagesAfterInitialSave = chatMemoryRepository.findByConversationId(conversationId);
List<Message> messagesAfterInitialSave = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(messagesAfterInitialSave).hasSize(2);
// 2. Action: Call saveAll with an empty list
chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
this.chatMemoryRepository.saveAll(conversationId, Collections.emptyList());
// 3. Assertions:
// a) No messages should be found for the conversationId
List<Message> messagesAfterEmptySave = chatMemoryRepository.findByConversationId(conversationId);
List<Message> messagesAfterEmptySave = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(messagesAfterEmptySave).isEmpty();
// b) The conversationId itself should no longer be listed (because
// deleteByConversationId removes the session node)
List<String> conversationIds = chatMemoryRepository.findConversationIds();
List<String> conversationIds = this.chatMemoryRepository.findConversationIds();
assertThat(conversationIds).doesNotContain(conversationId);
// c) Verify directly in Neo4j that the conversation node is gone
try (Session session = driver.session()) {
try (Session session = this.driver.session()) {
Result result = session.run(
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(config.getSessionLabel()),
"MATCH (s:%s {id: $conversationId}) RETURN s".formatted(this.config.getSessionLabel()),
Map.of("conversationId", conversationId));
assertThat(result.hasNext()).isFalse(); // No conversation node should exist
}
@@ -372,9 +375,9 @@ class Neo4jChatMemoryRepositoryIT {
.build();
List<Message> messagesToSave = List.of(messageWithEmptyContent, messageWithEmptyMetadata);
chatMemoryRepository.saveAll(conversationId, messagesToSave);
this.chatMemoryRepository.saveAll(conversationId, messagesToSave);
List<Message> retrievedMessages = chatMemoryRepository.findByConversationId(conversationId);
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
assertThat(retrievedMessages).hasSize(2);
// Verify first message (empty content)

View File

@@ -51,7 +51,6 @@ import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -70,6 +69,7 @@ import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.ResponseEntity;

View File

@@ -62,7 +62,7 @@ import org.springframework.web.reactive.function.client.WebClient;
* @author Claudio Silva Junior
* @since 1.0.0
*/
public class AnthropicApi {
public final class AnthropicApi {
public static Builder builder() {
return new Builder();

View File

@@ -142,7 +142,7 @@ public class AnthropicApiIT {
.maxTokens(1500)
.stream(true)
.temperature(0.8)
.tools(tools)
.tools(this.tools)
.build();
List<ChatCompletionResponse> responses = this.anthropicApi.chatCompletionStream(chatCompletionRequest)

View File

@@ -16,11 +16,13 @@
package org.springframework.ai.azure.openai;
import java.util.Map;
import java.util.Objects;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Map;
import java.util.Objects;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.StringUtils;

View File

@@ -32,7 +32,6 @@ import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
import org.springframework.ai.bedrock.converse.RequiresAwsCredentials;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClient.StreamResponseSpec;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.content.Media;
@@ -171,12 +170,6 @@ public class BedrockNovaChatClientIT {
assertThat(response).contains("30", "10", "15");
}
public record WeatherRequest(String location, String unit) {
}
public record WeatherResponse(int temp, String unit) {
}
// https://github.com/spring-projects/spring-ai/issues/1878
@Test
void toolAnnotationWeatherForecast() {
@@ -214,15 +207,6 @@ public class BedrockNovaChatClientIT {
assertThat(content).contains("20 degrees");
}
public static class DummyWeatherForecastTools {
@Tool(description = "Get the current weather forecast in Amsterdam")
String getCurrentDateTime() {
return "Weather is hot and sunny with a temperature of 20 degrees";
}
}
// https://github.com/spring-projects/spring-ai/issues/1878
@Test
void supplierBasedToolCalling() {
@@ -266,17 +250,6 @@ public class BedrockNovaChatClientIT {
assertThat(content).contains("30.0");
}
public static class WeatherService implements Supplier<WeatherService.Response> {
public record Response(double temp) {
}
public Response get() {
return new Response(30.0);
}
}
@SpringBootConfiguration
public static class Config {
@@ -295,4 +268,31 @@ public class BedrockNovaChatClientIT {
}
public record WeatherRequest(String location, String unit) {
}
public record WeatherResponse(int temp, String unit) {
}
public static class DummyWeatherForecastTools {
@Tool(description = "Get the current weather forecast in Amsterdam")
String getCurrentDateTime() {
return "Weather is hot and sunny with a temperature of 20 degrees";
}
}
public static class WeatherService implements Supplier<WeatherService.Response> {
public Response get() {
return new Response(30.0);
}
public record Response(double temp) {
}
}
}

View File

@@ -22,7 +22,8 @@ import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;

View File

@@ -20,6 +20,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -34,9 +36,6 @@ import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.Observation;
/**
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in

View File

@@ -22,6 +22,7 @@ import java.util.Base64;
import java.util.List;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.observation.tck.TestObservationRegistry;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
@@ -40,8 +41,6 @@ import org.springframework.core.io.DefaultResourceLoader;
import static org.assertj.core.api.Assertions.assertThat;
import io.micrometer.observation.tck.TestObservationRegistry;
@SpringBootTest
@RequiresAwsCredentials
class BedrockTitanEmbeddingModelIT {

View File

@@ -1,12 +1,28 @@
package org.springframework.ai.deepseek;
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.content.Media;
package org.springframework.ai.deepseek;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.content.Media;
public class DeepSeekAssistantMessage extends AssistantMessage {
private Boolean prefix;
@@ -50,7 +66,7 @@ public class DeepSeekAssistantMessage extends AssistantMessage {
}
public Boolean getPrefix() {
return prefix;
return this.prefix;
}
public void setPrefix(Boolean prefix) {
@@ -58,7 +74,7 @@ public class DeepSeekAssistantMessage extends AssistantMessage {
}
public String getReasoningContent() {
return reasoningContent;
return this.reasoningContent;
}
public void setReasoningContent(String reasoningContent) {

View File

@@ -16,16 +16,32 @@
package org.springframework.ai.deepseek;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.metadata.*;
import org.springframework.ai.chat.model.*;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
@@ -41,7 +57,11 @@ import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Too
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
import org.springframework.ai.deepseek.api.common.DeepSeekConstants;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.*;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
@@ -49,12 +69,6 @@ import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek}
@@ -248,12 +262,12 @@ public class DeepSeekChatModel implements ChatModel {
}
// @formatter:off
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
);
// @formatter:on
Map<String, Object> metadata = Map.of(
"id", chatCompletion2.id(),
"role", roleMap.getOrDefault(id, ""),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : ""
);
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
DeepSeekApi.Usage usage = chatCompletion2.usage();

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,12 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.deepseek.api.ResponseFormat;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -26,8 +37,6 @@ import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.*;
/**
* Chat completions options for the DeepSeek chat API.
* <a href="https://platform.deepseek.com/api-docs/api/create-chat-completion">DeepSeek
@@ -323,7 +332,7 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
return Objects.hash(this.model, this.frequencyPenalty, this.logprobs, this.topLogprobs,
this.maxTokens, this.presencePenalty, this.responseFormat,
this.stop, this.temperature, this.topP, this.tools, this.toolChoice,
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext);
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext);
}
@@ -351,6 +360,28 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
&& Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled);
}
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
return DeepSeekChatOptions.builder()
.model(fromOptions.getModel())
.frequencyPenalty(fromOptions.getFrequencyPenalty())
.logprobs(fromOptions.getLogprobs())
.topLogprobs(fromOptions.getTopLogprobs())
.maxTokens(fromOptions.getMaxTokens())
.presencePenalty(fromOptions.getPresencePenalty())
.responseFormat(fromOptions.getResponseFormat())
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.tools(fromOptions.getTools())
.toolChoice(fromOptions.getToolChoice())
.toolCallbacks(
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.build();
}
public static class Builder {
protected DeepSeekChatOptions options;
@@ -472,26 +503,4 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions {
}
public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) {
return DeepSeekChatOptions.builder()
.model(fromOptions.getModel())
.frequencyPenalty(fromOptions.getFrequencyPenalty())
.logprobs(fromOptions.getLogprobs())
.topLogprobs(fromOptions.getTopLogprobs())
.maxTokens(fromOptions.getMaxTokens())
.presencePenalty(fromOptions.getPresencePenalty())
.responseFormat(fromOptions.getResponseFormat())
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.tools(fromOptions.getTools())
.toolChoice(fromOptions.getToolChoice())
.toolCallbacks(
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.build();
}
}

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.aot;
import org.springframework.ai.deepseek.api.DeepSeekApi;
@@ -35,8 +36,9 @@ public class DeepSeekRuntimeHints implements RuntimeHintsRegistrar {
@Override
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
var mcs = MemberCategory.values();
for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class))
for (var tr : findJsonAnnotatedClassesInPackage(DeepSeekApi.class)) {
hints.reflection().registerType(tr, mcs);
}
}
}

View File

@@ -47,10 +47,6 @@ import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BASE_URL;
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BETA_PATH;
import static org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_COMPLETIONS_PATH;
/**
* Single class implementation of the DeepSeek Chat Completion API:
* https://platform.deepseek.com/api-docs/api/create-chat-completion
@@ -196,6 +192,19 @@ public class DeepSeekApi {
.flatMap(mono -> mono);
}
private String getEndpoint(ChatCompletionRequest request) {
boolean isPrefix = request.messages.stream()
.map(ChatCompletionMessage::prefix)
.filter(Objects::nonNull)
.anyMatch(prefix -> prefix);
String endpointPrefix = isPrefix ? this.betaPrefixPath : "";
return endpointPrefix + this.completionsPath;
}
public static Builder builder() {
return new Builder();
}
/**
* DeepSeek Chat Completion
* <a href="https://api-docs.deepseek.com/quick_start/pricing">Models</a>
@@ -226,12 +235,12 @@ public class DeepSeekApi {
}
public String getValue() {
return value;
return this.value;
}
@Override
public String getName() {
return value;
return this.value;
}
}
@@ -506,10 +515,9 @@ public class DeepSeekApi {
@JsonProperty("temperature") Double temperature,
@JsonProperty("top_p") Double topP,
@JsonProperty("logprobs") Boolean logprobs,
@JsonProperty("top_logprobs") Integer topLogprobs,
@JsonProperty("top_logprobs") Integer topLogprobs,
@JsonProperty("tools") List<FunctionTool> tools,
@JsonProperty("tool_choice") Object toolChoice)
{
@JsonProperty("tool_choice") Object toolChoice) {
/**
@@ -520,8 +528,8 @@ public class DeepSeekApi {
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
this(messages, null, null, null, null, null,
null, stream, null, null, null, null, null, null);
this(messages, null, null, null, null, null,
null, stream, null, null, null, null, null, null);
}
/**
@@ -532,9 +540,9 @@ public class DeepSeekApi {
* @param temperature What sampling temperature to use, between 0 and 1.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature) {
this(messages, model, null,
null, null, null, null, false, temperature, null,
null, null, null,null);
this(messages, model, null,
null, null, null, null, false, temperature, null,
null, null, null, null);
}
/**
@@ -547,9 +555,9 @@ public class DeepSeekApi {
* as they become available, with the stream terminated by a data: [DONE] message.
*/
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Double temperature, boolean stream) {
this(messages, model, null,
null, null, null, null, stream, temperature, null,
null, null, null,null);
this(messages, model, null,
null, null, null, null, stream, temperature, null,
null, null, null, null);
}
/**
@@ -600,8 +608,7 @@ public class DeepSeekApi {
@JsonProperty("tool_calls")
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("prefix") Boolean prefix,
@JsonProperty("reasoning_content") String reasoningContent
) { // @formatter:on
@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on
/**
* Create a chat completion message with the given content and role. All other
@@ -898,30 +905,17 @@ public class DeepSeekApi {
}
private String getEndpoint(ChatCompletionRequest request) {
boolean isPrefix = request.messages.stream()
.map(ChatCompletionMessage::prefix)
.filter(Objects::nonNull)
.anyMatch(prefix -> prefix);
String endpointPrefix = isPrefix ? betaPrefixPath : "";
return endpointPrefix + completionsPath;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private String baseUrl = DEFAULT_BASE_URL;
private String baseUrl = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BASE_URL;
private ApiKey apiKey;
private MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
private String completionsPath = DEFAULT_COMPLETIONS_PATH;
private String completionsPath = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_COMPLETIONS_PATH;
private String betaPrefixPath = DEFAULT_BETA_PATH;
private String betaPrefixPath = org.springframework.ai.deepseek.api.common.DeepSeekConstants.DEFAULT_BETA_PATH;
private RestClient.Builder restClientBuilder = RestClient.builder();

View File

@@ -16,6 +16,9 @@
package org.springframework.ai.deepseek.api;
import java.util.ArrayList;
import java.util.List;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk.ChunkChoice;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
@@ -25,9 +28,6 @@ import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Rol
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.ToolCall;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
/**
* Helper class to support Streaming function calling. It can merge the streamed
* ChatCompletionChunk in case of function calling message.

View File

@@ -16,12 +16,12 @@
package org.springframework.ai.deepseek.api;
import java.util.Objects;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Objects;
/**
* An object specifying the format that the model must output. Setting to { "type":
* "json_object" } enables JSON Output, which guarantees the message the model generates
@@ -42,7 +42,7 @@ import java.util.Objects;
*/
@JsonInclude(Include.NON_NULL)
public class ResponseFormat {
public final class ResponseFormat {
/**
* Type Must be one of 'text', 'json_object'.

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.api.common;
import org.springframework.ai.observation.conventions.AiProvider;
@@ -20,7 +21,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
/**
* @author Geng Rong
*/
public class DeepSeekConstants {
public final class DeepSeekConstants {
public static final String DEFAULT_BASE_URL = "https://api.deepseek.com";

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,9 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.api.DeepSeekApi;

View File

@@ -16,15 +16,22 @@
package org.springframework.ai.deepseek;
import java.util.List;
import java.util.Optional;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.deepseek.api.DeepSeekApi.*;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionFinishReason;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.http.ResponseEntity;
@@ -33,9 +40,6 @@ import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import java.util.List;
import java.util.Optional;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.isA;
@@ -65,7 +69,6 @@ public class DeepSeekRetryTests {
.defaultOptions(DeepSeekChatOptions.builder().build())
.retryTemplate(retryTemplate)
.build();
;
}
@Test

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek;
import org.springframework.ai.deepseek.api.DeepSeekApi;

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,15 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.aot;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import java.util.Set;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.springframework.ai.aot.AiRuntimeHints.findJsonAnnotatedClassesInPackage;
import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection;

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,16 +13,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.api;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.deepseek.api.DeepSeekApi.*;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
import org.springframework.http.ResponseEntity;
import reactor.core.publisher.Flux;
import java.util.List;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletion;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionChunk;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionMessage.Role;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatCompletionRequest;
import org.springframework.ai.deepseek.api.DeepSeekApi.ChatModel;
import org.springframework.http.ResponseEntity;
import static org.assertj.core.api.Assertions.assertThat;
@@ -37,7 +43,7 @@ public class DeepSeekApiIT {
@Test
void chatCompletionEntity() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
ResponseEntity<ChatCompletion> response = deepSeekApi.chatCompletionEntity(
ResponseEntity<ChatCompletion> response = this.deepSeekApi.chatCompletionEntity(
new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, false));
assertThat(response).isNotNull();
@@ -47,7 +53,7 @@ public class DeepSeekApiIT {
@Test
void chatCompletionStream() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
Flux<ChatCompletionChunk> response = deepSeekApi.chatCompletionStream(
Flux<ChatCompletionChunk> response = this.deepSeekApi.chatCompletionStream(
new ChatCompletionRequest(List.of(chatCompletionMessage), ChatModel.DEEPSEEK_CHAT.value, 1D, true));
assertThat(response).isNotNull();

View File

@@ -16,14 +16,14 @@
package org.springframework.ai.deepseek.api;
import java.util.function.Function;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import java.util.function.Function;
/**
* @author Geng Rong
*/

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.chat;
import java.util.List;
@@ -30,7 +31,7 @@ public class ActorsFilms {
}
public String getActor() {
return actor;
return this.actor;
}
public void setActor(String actor) {
@@ -38,7 +39,7 @@ public class ActorsFilms {
}
public List<String> getMovies() {
return movies;
return this.movies;
}
public void setMovies(List<String> movies) {
@@ -47,7 +48,7 @@ public class ActorsFilms {
@Override
public String toString() {
return "ActorsFilms{" + "actor='" + actor + '\'' + ", movies=" + movies + '}';
return "ActorsFilms{" + "actor='" + this.actor + '\'' + ", movies=" + this.movies + '}';
}
}

View File

@@ -16,10 +16,18 @@
package org.springframework.ai.deepseek.chat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
@@ -34,13 +42,6 @@ import org.springframework.ai.deepseek.api.MockWeatherService;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;

View File

@@ -1,11 +1,11 @@
/*
* Copyright 2023 - 2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,12 +13,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.deepseek.chat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
@@ -32,21 +41,16 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.DeepSeekTestConfiguration;
import org.springframework.ai.deepseek.DeepSeekAssistantMessage;
import org.springframework.ai.deepseek.api.DeepSeekApi;
import org.springframework.ai.deepseek.api.MockWeatherService;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.Resource;
import java.util.*;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
/**
@@ -71,10 +75,10 @@ class DeepSeekChatModelIT {
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatModel.call(prompt);
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
@@ -118,7 +122,7 @@ class DeepSeekChatModelIT {
format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatModel.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -141,14 +145,11 @@ class DeepSeekChatModelIT {
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatModel.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
}
record ActorsFilmsRecord(String actor, List<String> movies) {
}
@Test
void beanOutputConverterRecords() {
@@ -165,7 +166,7 @@ class DeepSeekChatModelIT {
.variables(Map.of("format", format))
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatModel.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
logger.info("" + actorsFilms);
@@ -190,7 +191,7 @@ class DeepSeekChatModelIT {
.build();
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = streamingChatModel.stream(prompt)
String generationTextFromStream = this.streamingChatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -225,7 +226,7 @@ class DeepSeekChatModelIT {
UserMessage userMessage = new UserMessage(userMessageContent);
Message assistantMessage = new DeepSeekAssistantMessage("{\"code\":200,\"result\":{\"total\":1,\"data\":[1");
Prompt prompt = new Prompt(List.of(userMessage, assistantMessage));
ChatResponse response = chatModel.call(prompt);
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResult().getOutput().getText().equals(",2,3]}}"));
}
@@ -239,7 +240,7 @@ class DeepSeekChatModelIT {
.model(DeepSeekApi.ChatModel.DEEPSEEK_REASONER.getValue())
.build();
Prompt prompt = new Prompt("9.11 and 9.8, which is greater?", promptOptions);
ChatResponse response = chatModel.call(prompt);
ChatResponse response = this.chatModel.call(prompt);
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
@@ -258,7 +259,7 @@ class DeepSeekChatModelIT {
.build();
Prompt prompt = new Prompt(messages, promptOptions);
ChatResponse response = chatModel.call(prompt);
ChatResponse response = this.chatModel.call(prompt);
DeepSeekAssistantMessage deepSeekAssistantMessage = (DeepSeekAssistantMessage) response.getResult().getOutput();
assertThat(deepSeekAssistantMessage.getReasoningContent()).isNotEmpty();
@@ -267,7 +268,7 @@ class DeepSeekChatModelIT {
messages.add(new AssistantMessage(Objects.requireNonNull(deepSeekAssistantMessage.getText())));
messages.add(new UserMessage("How many Rs are there in the word 'strawberry'?"));
Prompt prompt2 = new Prompt(messages, promptOptions);
ChatResponse response2 = chatModel.call(prompt2);
ChatResponse response2 = this.chatModel.call(prompt2);
DeepSeekAssistantMessage deepSeekAssistantMessage2 = (DeepSeekAssistantMessage) response2.getResult()
.getOutput();
@@ -275,4 +276,7 @@ class DeepSeekChatModelIT {
assertThat(deepSeekAssistantMessage2.getText()).isNotEmpty();
}
record ActorsFilmsRecord(String actor, List<String> movies) {
}
}

View File

@@ -16,11 +16,16 @@
package org.springframework.ai.deepseek.chat;
import java.util.List;
import java.util.stream.Collectors;
import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
@@ -36,10 +41,6 @@ import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.retry.support.RetryTemplate;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.stream.Collectors;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.chat.observation.ChatModelObservationDocumentation.HighCardinalityKeyNames;

View File

@@ -40,7 +40,6 @@ import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.minimax.api.MiniMaxApiConstants;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

View File

@@ -39,7 +39,6 @@ import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -66,6 +65,7 @@ import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;

View File

@@ -364,14 +364,14 @@ class MistralAiChatModelIT {
UserMessage userMessage1 = new UserMessage("My name is James Bond");
memory.add(conversationId, userMessage1);
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response1).isNotNull();
memory.add(conversationId, response1.getResult().getOutput());
UserMessage userMessage2 = new UserMessage("What is my name?");
memory.add(conversationId, userMessage2);
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response2).isNotNull();
memory.add(conversationId, response2.getResult().getOutput());
@@ -396,7 +396,7 @@ class MistralAiChatModelIT {
chatMemory.add(conversationId, prompt.getInstructions());
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
ChatResponse chatResponse = chatModel.call(promptWithMemory);
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
while (chatResponse.hasToolCalls()) {
@@ -405,7 +405,7 @@ class MistralAiChatModelIT {
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
.get(toolExecutionResult.conversationHistory().size() - 1));
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
chatResponse = chatModel.call(promptWithMemory);
chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
}
@@ -415,7 +415,7 @@ class MistralAiChatModelIT {
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
chatMemory.add(conversationId, newUserMessage);
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
assertThat(newResponse).isNotNull();
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");

View File

@@ -53,7 +53,6 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConven
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.oci.ServingModeHelper;
import org.springframework.util.Assert;

View File

@@ -267,16 +267,19 @@ public class OCICohereChatOptions implements ChatOptions {
@Override
public int hashCode() {
return Objects.hash(model, maxTokens, compartment, servingMode, preambleOverride, temperature, topP, topK, stop,
frequencyPenalty, presencePenalty, documents, tools);
return Objects.hash(this.model, this.maxTokens, this.compartment, this.servingMode, this.preambleOverride,
this.temperature, this.topP, this.topK, this.stop, this.frequencyPenalty, this.presencePenalty,
this.documents, this.tools);
}
@Override
public boolean equals(Object o) {
if (this == o)
if (this == o) {
return true;
if (o == null || getClass() != o.getClass())
}
if (o == null || getClass() != o.getClass()) {
return false;
}
OCICohereChatOptions that = (OCICohereChatOptions) o;

View File

@@ -16,12 +16,12 @@
package org.springframework.ai.oci.cohere;
import com.oracle.bmc.generativeaiinference.model.CohereTool;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
import com.oracle.bmc.generativeaiinference.model.CohereTool;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**

View File

@@ -344,7 +344,8 @@ public class OllamaChatModel implements ChatModel {
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
} else {
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);

View File

@@ -54,9 +54,11 @@ import org.springframework.web.reactive.function.client.WebClient;
* @since 0.8.0
*/
// @formatter:off
public class OllamaApi {
public final class OllamaApi {
public static Builder builder() { return new Builder(); }
public static Builder builder() {
return new Builder();
}
public static final String REQUEST_BODY_NULL_ERROR = "The request body can not be null.";

View File

@@ -280,14 +280,14 @@ class OllamaChatModelIT extends BaseOllamaIT {
UserMessage userMessage1 = new UserMessage("My name is James Bond");
memory.add(conversationId, userMessage1);
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response1).isNotNull();
memory.add(conversationId, response1.getResult().getOutput());
UserMessage userMessage2 = new UserMessage("What is my name?");
memory.add(conversationId, userMessage2);
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response2).isNotNull();
memory.add(conversationId, response2.getResult().getOutput());
@@ -312,7 +312,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
chatMemory.add(conversationId, prompt.getInstructions());
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
ChatResponse chatResponse = chatModel.call(promptWithMemory);
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
while (chatResponse.hasToolCalls()) {
@@ -321,7 +321,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
.get(toolExecutionResult.conversationHistory().size() - 1));
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
chatResponse = chatModel.call(promptWithMemory);
chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
}
@@ -331,7 +331,7 @@ class OllamaChatModelIT extends BaseOllamaIT {
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
chatMemory.add(conversationId, newUserMessage);
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
assertThat(newResponse).isNotNull();
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");

View File

@@ -43,7 +43,6 @@ import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -74,6 +73,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;

View File

@@ -296,6 +296,31 @@ public class OpenAiApi {
});
}
// Package-private getters for mutate/copy
String getBaseUrl() {
return this.baseUrl;
}
ApiKey getApiKey() {
return this.apiKey;
}
MultiValueMap<String, String> getHeaders() {
return this.headers;
}
String getCompletionsPath() {
return this.completionsPath;
}
String getEmbeddingsPath() {
return this.embeddingsPath;
}
ResponseErrorHandler getResponseErrorHandler() {
return this.responseErrorHandler;
}
/**
* OpenAI Chat Completion Models.
* <p>
@@ -1193,7 +1218,7 @@ public class OpenAiApi {
*/
@JsonInclude(Include.NON_NULL)
public record WebSearchOptions(@JsonProperty("search_context_size") SearchContextSize searchContextSize,
@JsonProperty("user_location") UserLocation userLocation) {
@JsonProperty("user_location") UserLocation userLocation) {
/**
* High level guidance for the amount of context window space to use for the
@@ -1229,11 +1254,11 @@ public class OpenAiApi {
*/
@JsonInclude(Include.NON_NULL)
public record UserLocation(@JsonProperty("type") String type,
@JsonProperty("approximate") Approximate approximate) {
@JsonProperty("approximate") Approximate approximate) {
@JsonInclude(Include.NON_NULL)
public record Approximate(@JsonProperty("city") String city, @JsonProperty("country") String country,
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
@JsonProperty("region") String region, @JsonProperty("timezone") String timezone) {
}
}
@@ -1891,29 +1916,4 @@ public class OpenAiApi {
}
// Package-private getters for mutate/copy
String getBaseUrl() {
return this.baseUrl;
}
ApiKey getApiKey() {
return this.apiKey;
}
MultiValueMap<String, String> getHeaders() {
return this.headers;
}
String getCompletionsPath() {
return this.completionsPath;
}
String getEmbeddingsPath() {
return this.embeddingsPath;
}
ResponseErrorHandler getResponseErrorHandler() {
return this.responseErrorHandler;
}
}

View File

@@ -36,26 +36,29 @@ class OpenAiChatModelMutateTests {
private final OpenAiApi baseApi = OpenAiApi.builder().baseUrl("https://api.openai.com").apiKey("base-key").build();
private final OpenAiChatModel baseModel = OpenAiChatModel.builder()
.openAiApi(baseApi)
.openAiApi(this.baseApi)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-3.5-turbo").build())
.build();
@Test
void testMutateCreatesDistinctClientsWithDifferentEndpointsAndModels() {
// Mutate for GPT-4
OpenAiApi gpt4Api = baseApi.mutate().baseUrl("https://api.openai.com").apiKey("your-api-key-for-gpt4").build();
OpenAiChatModel gpt4Model = baseModel.mutate()
OpenAiApi gpt4Api = this.baseApi.mutate()
.baseUrl("https://api.openai.com")
.apiKey("your-api-key-for-gpt4")
.build();
OpenAiChatModel gpt4Model = this.baseModel.mutate()
.openAiApi(gpt4Api)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
.build();
ChatClient gpt4Client = ChatClient.builder(gpt4Model).build();
// Mutate for Llama
OpenAiApi llamaApi = baseApi.mutate()
OpenAiApi llamaApi = this.baseApi.mutate()
.baseUrl("https://your-custom-endpoint.com")
.apiKey("your-api-key-for-llama")
.build();
OpenAiChatModel llamaModel = baseModel.mutate()
OpenAiChatModel llamaModel = this.baseModel.mutate()
.openAiApi(llamaApi)
.defaultOptions(OpenAiChatOptions.builder().model("llama-70b").temperature(0.5).build())
.build();
@@ -72,29 +75,29 @@ class OpenAiChatModelMutateTests {
@Test
void testCloneCreatesDeepCopy() {
OpenAiChatModel clone = baseModel.clone();
assertThat(clone).isNotSameAs(baseModel);
assertThat(clone.toString()).isEqualTo(baseModel.toString());
OpenAiChatModel clone = this.baseModel.clone();
assertThat(clone).isNotSameAs(this.baseModel);
assertThat(clone.toString()).isEqualTo(this.baseModel.toString());
}
@Test
void mutateDoesNotAffectOriginal() {
OpenAiChatModel mutated = baseModel.mutate()
OpenAiChatModel mutated = this.baseModel.mutate()
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").build())
.build();
assertThat(mutated).isNotSameAs(baseModel);
assertThat(mutated).isNotSameAs(this.baseModel);
assertThat(mutated.getDefaultOptions().getModel()).isEqualTo("gpt-4");
assertThat(baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
assertThat(this.baseModel.getDefaultOptions().getModel()).isEqualTo("gpt-3.5-turbo");
}
@Test
void mutateHeadersCreatesDistinctHeaders() {
OpenAiApi mutatedApi = baseApi.mutate()
OpenAiApi mutatedApi = this.baseApi.mutate()
.headers(new LinkedMultiValueMap<>(java.util.Map.of("X-Test", java.util.List.of("value"))))
.build();
assertThat(mutatedApi.getHeaders()).containsKey("X-Test");
assertThat(baseApi.getHeaders()).doesNotContainKey("X-Test");
assertThat(this.baseApi.getHeaders()).doesNotContainKey("X-Test");
}
@Test
@@ -108,7 +111,9 @@ class OpenAiChatModelMutateTests {
@Test
void multipleSequentialMutationsProduceDistinctInstances() {
OpenAiChatModel m1 = baseModel.mutate().defaultOptions(OpenAiChatOptions.builder().model("m1").build()).build();
OpenAiChatModel m1 = this.baseModel.mutate()
.defaultOptions(OpenAiChatOptions.builder().model("m1").build())
.build();
OpenAiChatModel m2 = m1.mutate().defaultOptions(OpenAiChatOptions.builder().model("m2").build()).build();
OpenAiChatModel m3 = m2.mutate().defaultOptions(OpenAiChatOptions.builder().model("m3").build()).build();
assertThat(m1).isNotSameAs(m2);
@@ -120,8 +125,8 @@ class OpenAiChatModelMutateTests {
@Test
void mutateAndCloneAreEquivalent() {
OpenAiChatModel mutated = baseModel.mutate().build();
OpenAiChatModel cloned = baseModel.clone();
OpenAiChatModel mutated = this.baseModel.mutate().build();
OpenAiChatModel cloned = this.baseModel.clone();
assertThat(mutated.toString()).isEqualTo(cloned.toString());
assertThat(mutated).isNotSameAs(cloned);
}

View File

@@ -18,7 +18,6 @@ package org.springframework.ai.openai.chat;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.util.List;
import java.util.Map;

View File

@@ -643,14 +643,14 @@ public class OpenAiChatModelIT extends AbstractIT {
UserMessage userMessage1 = new UserMessage("My name is James Bond");
memory.add(conversationId, userMessage1);
ChatResponse response1 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response1 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response1).isNotNull();
memory.add(conversationId, response1.getResult().getOutput());
UserMessage userMessage2 = new UserMessage("What is my name?");
memory.add(conversationId, userMessage2);
ChatResponse response2 = chatModel.call(new Prompt(memory.get(conversationId)));
ChatResponse response2 = this.chatModel.call(new Prompt(memory.get(conversationId)));
assertThat(response2).isNotNull();
memory.add(conversationId, response2.getResult().getOutput());
@@ -675,7 +675,7 @@ public class OpenAiChatModelIT extends AbstractIT {
chatMemory.add(conversationId, prompt.getInstructions());
Prompt promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
ChatResponse chatResponse = chatModel.call(promptWithMemory);
ChatResponse chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
while (chatResponse.hasToolCalls()) {
@@ -684,7 +684,7 @@ public class OpenAiChatModelIT extends AbstractIT {
chatMemory.add(conversationId, toolExecutionResult.conversationHistory()
.get(toolExecutionResult.conversationHistory().size() - 1));
promptWithMemory = new Prompt(chatMemory.get(conversationId), chatOptions);
chatResponse = chatModel.call(promptWithMemory);
chatResponse = this.chatModel.call(promptWithMemory);
chatMemory.add(conversationId, chatResponse.getResult().getOutput());
}
@@ -694,21 +694,12 @@ public class OpenAiChatModelIT extends AbstractIT {
UserMessage newUserMessage = new UserMessage("What did I ask you earlier?");
chatMemory.add(conversationId, newUserMessage);
ChatResponse newResponse = chatModel.call(new Prompt(chatMemory.get(conversationId)));
ChatResponse newResponse = this.chatModel.call(new Prompt(chatMemory.get(conversationId)));
assertThat(newResponse).isNotNull();
assertThat(newResponse.getResult().getOutput().getText()).contains("6").contains("8");
}
static class MathTools {
@Tool(description = "Multiply the two numbers")
double multiply(double a, double b) {
return a * b;
}
}
@Test
void webSearchAnnotationsTest() {
UserMessage userMessage = new UserMessage("What is the latest news on the Mars rover?");
@@ -779,4 +770,13 @@ public class OpenAiChatModelIT extends AbstractIT {
}
static class MathTools {
@Tool(description = "Multiply the two numbers")
double multiply(double a, double b) {
return a * b;
}
}
}

View File

@@ -28,10 +28,10 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;

View File

@@ -1,3 +1,19 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.chat.client;
import java.util.List;
@@ -41,7 +57,7 @@ class OpenAiChatClientMemoryAdvisorReproIT {
.build();
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Act: call should succeed without exception (issue #2339 is fixed)
chatClient.prompt(prompt).call().chatResponse(); // Should not throw

View File

@@ -57,7 +57,7 @@ public class ReReadingAdvisor implements BaseAdvisor {
@Override
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
String augmentedUserText = PromptTemplate.builder()
.template(re2AdviseTemplate)
.template(this.re2AdviseTemplate)
.variables(Map.of("re2_input_query", chatClientRequest.prompt().getUserMessage().getText()))
.build()
.render();

View File

@@ -81,7 +81,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Create a prompt with multiple user messages
List<Message> messages = new ArrayList<>();
@@ -131,7 +131,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
// Create advisor with the conversation ID
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Act - Create a list of messages for the prompt
List<Message> messages = new ArrayList<>();
@@ -193,7 +193,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
// Create advisor without a default conversation ID
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
String question = "What is the capital of Germany?";
@@ -231,7 +231,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
// Create advisor without a default conversation ID
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Act - First conversation
String answer1 = chatClient.prompt()
@@ -316,7 +316,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
// Create advisor without a default conversation ID
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Act - Send a question to a non-existent conversation
String question = "Do you remember our previous conversation?";
@@ -375,7 +375,7 @@ public abstract class AbstractChatMemoryAdvisorIT {
var advisor = createAdvisor(chatMemory);
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
List<String> responseList = new ArrayList<>();
for (String message : List.of("My name is Charlie.", "I am 30 years old.", "I live in London.")) {

View File

@@ -91,7 +91,7 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
.conversationId(conversationId)
.build();
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Create a prompt with multiple user messages
List<Message> messages = new ArrayList<>();
@@ -151,7 +151,7 @@ public class MessageChatMemoryAdvisorIT extends AbstractChatMemoryAdvisorIT {
.conversationId(conversationId)
.build();
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
// Act - Use streaming API
String userInput = "Tell me a short joke about programming";

View File

@@ -50,25 +50,25 @@ class MultiOpenAiClientIT {
@Test
void multiClientFlow() {
// Derive a new OpenAiApi for Groq (Llama3)
OpenAiApi groqApi = baseOpenAiApi.mutate()
OpenAiApi groqApi = this.baseOpenAiApi.mutate()
.baseUrl("https://api.groq.com/openai")
.apiKey(System.getenv("GROQ_API_KEY"))
.build();
// Derive a new OpenAiApi for OpenAI GPT-4
OpenAiApi gpt4Api = baseOpenAiApi.mutate()
OpenAiApi gpt4Api = this.baseOpenAiApi.mutate()
.baseUrl("https://api.openai.com")
.apiKey(System.getenv("OPENAI_API_KEY"))
.build();
// Derive a new OpenAiChatModel for Groq
OpenAiChatModel groqModel = baseChatModel.mutate()
OpenAiChatModel groqModel = this.baseChatModel.mutate()
.openAiApi(groqApi)
.defaultOptions(OpenAiChatOptions.builder().model("llama3-70b-8192").temperature(0.5).build())
.build();
// Derive a new OpenAiChatModel for GPT-4
OpenAiChatModel gpt4Model = baseChatModel.mutate()
OpenAiChatModel gpt4Model = this.baseChatModel.mutate()
.openAiApi(gpt4Api)
.defaultOptions(OpenAiChatOptions.builder().model("gpt-4").temperature(0.7).build())
.build();

Some files were not shown because too many files have changed in this diff Show More