Remove deprecations from 1.0.0-M6

- Remove deprecations from models, vector stores and usage
- Deprecations from FunctionCallback and ObservationContext/Convention will be in a separate PR

Models updates
  - Remove AbstractToolCallSupport from the models which use ToolCallingManager
  - Remove deprecated constructors and their usage
  - Remove FunctionCallbackResolver and FunctionCallbacks usage in the models

- Add back deprecations for VectorStoreChatMemoryAdvisor until builder is fixed

- Update OpenAiPaymentTransactionIT to use ToolCallbackResolver in config

Signed-off-by: Ilayaperumal Gopinathan <ilayaperumal.gopinathan@broadcom.com>
This commit is contained in:
Ilayaperumal Gopinathan
2025-02-17 14:41:27 +00:00
committed by Mark Pollack
parent a1e417f350
commit ded9facfe5
80 changed files with 296 additions and 1962 deletions

View File

@@ -30,13 +30,6 @@ import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
@@ -59,7 +52,6 @@ import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -72,10 +64,12 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
@@ -94,7 +88,7 @@ import org.springframework.util.StringUtils;
* @author Alexandros Pappas
* @since 1.0.0
*/
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
public class AnthropicChatModel implements ChatModel {
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getValue();
@@ -135,111 +129,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi) {
this(anthropicApi,
AnthropicChatOptions.builder()
.model(DEFAULT_MODEL_NAME)
.maxTokens(DEFAULT_MAX_TOKENS)
.temperature(DEFAULT_TEMPERATURE)
.build());
}
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate) {
this(anthropicApi, defaultOptions, retryTemplate, null);
}
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) {
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of());
}
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver,
List<FunctionCallback> toolFunctionCallbacks) {
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, toolFunctionCallbacks,
ObservationRegistry.NOOP);
}
/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, @Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {
this(anthropicApi, defaultOptions,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build(),
retryTemplate, observationRegistry);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, AnthropicChatOptions.builder().build(), List.of());
Assert.notNull(anthropicApi, "anthropicApi cannot be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
@@ -488,10 +380,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
AnthropicChatOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
AnthropicChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AnthropicChatOptions.class);
@@ -648,10 +536,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolCallbacks;
private ToolCallingManager toolCallingManager;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@@ -679,18 +563,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
return this;
}
@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
@Deprecated
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.toolCallbacks = toolCallbacks;
return this;
}
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
@@ -698,22 +570,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
public AnthropicChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolCallbacks, "toolCallbacks cannot be set when toolCallingManager is set");
return new AnthropicChatModel(anthropicApi, defaultOptions, toolCallingManager, retryTemplate,
observationRegistry);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolCallbacks != null ? this.toolCallbacks : List.of();
return new AnthropicChatModel(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver,
toolCallbacks, observationRegistry);
}
return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
observationRegistry);
}

View File

@@ -34,6 +34,8 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
@@ -171,8 +173,7 @@ public class AnthropicChatModelObservationIT {
public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi,
TestObservationRegistry observationRegistry) {
return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(),
RetryTemplate.defaultInstance(), new DefaultFunctionCallbackResolver(), List.of(),
observationRegistry);
ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry);
}
}

View File

@@ -73,7 +73,6 @@ import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -86,16 +85,12 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -120,7 +115,7 @@ import org.springframework.util.CollectionUtils;
* @see com.azure.ai.openai.OpenAIClient
* @since 1.0.0
*/
public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
public class AzureOpenAiChatModel implements ChatModel {
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModel.class);
@@ -162,53 +157,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
*/
private final ToolCallingManager toolCallingManager;
@Deprecated
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
this(openAIClientBuilder,
AzureOpenAiChatOptions.builder()
.deploymentName(DEFAULT_DEPLOYMENT_NAME)
.temperature(DEFAULT_TEMPERATURE)
.build());
}
@Deprecated
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) {
this(openAIClientBuilder, options, null);
}
@Deprecated
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
FunctionCallbackResolver functionCallbackResolver) {
this(openAIClientBuilder, options, functionCallbackResolver, List.of());
}
@Deprecated
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks) {
this(openAIClientBuilder, options, functionCallbackResolver, toolFunctionCallbacks, ObservationRegistry.NOOP);
}
@Deprecated
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {
this(openAIClientBuilder, options,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build(),
observationRegistry);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the AzureOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions,
ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, AzureOpenAiChatOptions.builder().build(), List.of());
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
@@ -534,10 +484,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
if (prompt.getOptions() != null) {
AzureOpenAiChatOptions updatedRuntimeOptions;
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, AzureOpenAiChatOptions.class);
}
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions,
ToolCallingChatOptions.class, AzureOpenAiChatOptions.class);
@@ -668,10 +614,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
AzureOpenAiChatOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
AzureOpenAiChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AzureOpenAiChatOptions.class);
@@ -978,10 +920,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
private ToolCallingManager toolCallingManager;
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolFunctionCallbacks;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
private Builder() {
@@ -1002,18 +940,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
return this;
}
@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
@Deprecated
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
@@ -1021,29 +947,9 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
public AzureOpenAiChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager,
observationRegistry);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
: List.of();
return new Builder().openAIClientBuilder(openAIClientBuilder)
.defaultOptions(defaultOptions)
.functionCallbackResolver(functionCallbackResolver)
.toolFunctionCallbacks(toolCallbacks)
.observationRegistry(observationRegistry)
.build();
}
return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
observationRegistry);
}

View File

@@ -82,7 +82,6 @@ import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -95,10 +94,6 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
@@ -135,7 +130,7 @@ import org.springframework.util.StringUtils;
* @author Jihoon Kim
* @since 1.0.0
*/
public class BedrockProxyChatModel extends AbstractToolCallSupport implements ChatModel {
public class BedrockProxyChatModel implements ChatModel {
private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModel.class);
@@ -161,33 +156,10 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
*/
private ChatModelObservationConvention observationConvention;
/**
* @deprecated Use
* {@link #BedrockProxyChatModel(BedrockRuntimeClient, BedrockRuntimeAsyncClient, ToolCallingChatOptions, ObservationRegistry, ToolCallingManager)}
* instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions,
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
ObservationRegistry observationRegistry) {
this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, from(defaultOptions), observationRegistry,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build());
}
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, FunctionCallingOptions.builder().build(), List.of());
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
@@ -199,21 +171,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
this.toolCallingManager = toolCallingManager;
}
@Deprecated
private static ToolCallingChatOptions from(FunctionCallingOptions options) {
return ToolCallingChatOptions.builder()
.model(options.getModel())
.maxTokens(options.getMaxTokens())
.stopSequences(options.getStopSequences())
.temperature(options.getTemperature())
.topP(options.getTopP())
.toolCallbacks(options.getFunctionCallbacks())
.toolNames(options.getFunctions())
.internalToolExecutionEnabled(options.getProxyToolCalls() != null ? !options.getProxyToolCalls() : false)
.toolContext(options.getToolContext())
.build();
}
private static ToolCallingChatOptions from(ChatOptions options) {
return ToolCallingChatOptions.builder()
.model(options.getModel())
@@ -295,9 +252,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
runtimeOptions = toolCallingChatOptions.copy();
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = from(functionCallingOptions);
}
else {
runtimeOptions = from(prompt.getOptions());
}
@@ -804,10 +758,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build();
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolFunctionCallbacks;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
private ChatModelObservationConvention customObservationConvention;
@@ -824,160 +774,47 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
return this;
}
/**
* @deprecated Use {@link #credentialsProvider(AwsCredentialsProvider)} instead.
*/
@Deprecated
public Builder withCredentialsProvider(AwsCredentialsProvider credentialsProvider) {
Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null.");
this.credentialsProvider = credentialsProvider;
return this;
}
public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) {
Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null.");
this.credentialsProvider = credentialsProvider;
return this;
}
/**
* @deprecated Use {@link #region(Region)} instead.
*/
@Deprecated
public Builder withRegion(Region region) {
Assert.notNull(region, "'region' must not be null.");
this.region = region;
return this;
}
public Builder region(Region region) {
Assert.notNull(region, "'region' must not be null.");
this.region = region;
return this;
}
/**
* @deprecated Use {@link #timeout(Duration)} instead.
*/
@Deprecated
public Builder withTimeout(Duration timeout) {
Assert.notNull(timeout, "'timeout' must not be null.");
this.timeout = timeout;
return this;
}
public Builder timeout(Duration timeout) {
Assert.notNull(timeout, "'timeout' must not be null.");
this.timeout = timeout;
return this;
}
/**
* @deprecated Use {@link #defaultOptions(ToolCallingChatOptions)} instead.
*/
@Deprecated
public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) {
Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
return this.defaultOptions(ToolCallingChatOptions.builder()
.model(defaultOptions.getModel())
.maxTokens(defaultOptions.getMaxTokens())
.stopSequences(defaultOptions.getStopSequences())
.temperature(defaultOptions.getTemperature())
.topP(defaultOptions.getTopP())
.toolCallbacks(defaultOptions.getFunctionCallbacks())
.toolNames(defaultOptions.getFunctions())
.internalToolExecutionEnabled(
defaultOptions.getProxyToolCalls() != null ? !defaultOptions.getProxyToolCalls() : false)
.toolContext(defaultOptions.getToolContext())
.build());
}
public Builder defaultOptions(ToolCallingChatOptions defaultOptions) {
Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
this.defaultOptions = defaultOptions;
return this;
}
/**
* @deprecated To be removed after M6
*/
@Deprecated
public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
/**
* @deprecated To be removed after M6
*/
@Deprecated
public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}
/**
* @deprecated Use {@link #observationRegistry(ObservationRegistry)} instead.
*/
@Deprecated
public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
Assert.notNull(observationRegistry, "'observationRegistry' must not be null.");
this.observationRegistry = observationRegistry;
return this;
}
public Builder observationRegistry(ObservationRegistry observationRegistry) {
Assert.notNull(observationRegistry, "'observationRegistry' must not be null.");
this.observationRegistry = observationRegistry;
return this;
}
/**
* @deprecated Use
* {@link #customObservationConvention(ChatModelObservationConvention)} instead.
*/
@Deprecated
public Builder withCustomObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "'observationConvention' must not be null.");
this.customObservationConvention = observationConvention;
return this;
}
public Builder customObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "'observationConvention' must not be null.");
this.customObservationConvention = observationConvention;
return this;
}
/**
* @deprecated Use {@link #bedrockRuntimeClient(BedrockRuntimeClient)} instead.
*/
@Deprecated
public Builder withBedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
return this;
}
public Builder bedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
this.bedrockRuntimeClient = bedrockRuntimeClient;
return this;
}
/**
* @deprecated Use {@link #bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient)}
* instead.
*/
@Deprecated
public Builder withBedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
return this;
}
public Builder bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
return this;
@@ -1013,23 +850,11 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
BedrockProxyChatModel bedrockProxyChatModel = null;
if (this.toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry,
this.toolCallingManager);
}
else if (this.functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver,
this.toolFunctionCallbacks, this.observationRegistry);
}
else {
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry,

View File

@@ -36,7 +36,6 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
@@ -367,7 +366,7 @@ class BedrockConverseChatClientIT {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().model(modelName).build())
.options(ToolCallingChatOptions.builder().model(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?")
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
.call()
@@ -389,7 +388,7 @@ class BedrockConverseChatClientIT {
// @formatter:off
String response = ChatClient.create(this.chatModel).prompt()
// TODO consider adding model(...) method to ChatClient as a shortcut to
.options(FunctionCallingOptions.builder().model(modelName).build())
.options(ToolCallingChatOptions.builder().model(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
.call()
.content();

View File

@@ -21,7 +21,7 @@ import java.time.Duration;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
@@ -42,7 +42,7 @@ public class BedrockConverseTestConfiguration {
.region(Region.US_EAST_1)
// .region(Region.US_EAST_1)
.timeout(Duration.ofSeconds(120))
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
.build();
}

View File

@@ -16,11 +16,7 @@
package org.springframework.ai.bedrock.converse;
import java.util.List;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
@@ -40,8 +36,9 @@ import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.function.FunctionToolCallback;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.isA;
@@ -61,8 +58,10 @@ public class BedrockConverseUsageAggregationTests {
@BeforeEach
public void beforeEach() {
this.chatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient,
FunctionCallingOptions.builder().build(), null, List.of(), ObservationRegistry.NOOP);
this.chatModel = BedrockProxyChatModel.builder()
.bedrockRuntimeClient(this.bedrockRuntimeClient)
.bedrockRuntimeAsyncClient(this.bedrockRuntimeAsyncClient)
.build();
}
@Test
@@ -138,14 +137,13 @@ public class BedrockConverseUsageAggregationTests {
given(this.bedrockRuntimeClient.converse(isA(ConverseRequest.class))).willReturn(converseResponseToolUse)
.willReturn(converseResponseFinal);
FunctionCallback functionCallback = FunctionCallback.builder()
.function("getCurrentWeather", (Request request) -> "15.0°C")
ToolCallback toolCallback = FunctionToolCallback.builder("getCurrentWeather", (Request request) -> "15.0°C")
.description("Gets the weather in location")
.inputType(Request.class)
.build();
var result = this.chatModel.call(new Prompt("What is the weather in Paris?",
FunctionCallingOptions.builder().functionCallbacks(functionCallback).build()));
ToolCallingChatOptions.builder().toolCallbacks(toolCallback).build()));
assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getText())

View File

@@ -92,7 +92,7 @@ class BedrockProxyChatModelIT {
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage),
FunctionCallingOptions.builder().model(modelName).build());
ToolCallingChatOptions.builder().model(modelName).build());
ChatResponse response = this.chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
@@ -128,7 +128,7 @@ class BedrockProxyChatModelIT {
@Test
void streamingWithTokenUsage() {
var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build();
var promptOptions = ToolCallingChatOptions.builder().temperature(0.0).build();
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
@@ -255,9 +255,8 @@ class BedrockProxyChatModelIT {
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = FunctionCallingOptions.builder()
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
var promptOptions = ToolCallingChatOptions.builder()
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
@@ -305,10 +304,9 @@ class BedrockProxyChatModelIT {
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = FunctionCallingOptions.builder()
var promptOptions = ToolCallingChatOptions.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description(
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
.inputType(MockWeatherService.Request.class)
@@ -333,7 +331,7 @@ class BedrockProxyChatModelIT {
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
// @formatter:off
ChatResponse response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().model(model).build())
.options(ToolCallingChatOptions.builder().model(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
@@ -348,7 +346,7 @@ class BedrockProxyChatModelIT {
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
// @formatter:off
ChatResponse response = ChatClient.create(this.chatModel).prompt()
.options(FunctionCallingOptions.builder().model(model).build())
.options(ToolCallingChatOptions.builder().model(model).build())
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.stream()
.chatResponse()

View File

@@ -34,7 +34,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
@@ -67,7 +67,7 @@ public class BedrockProxyChatModelObservationIT {
@Test
void observationForChatOperation() {
var options = FunctionCallingOptions.builder()
var options = ToolCallingChatOptions.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.maxTokens(2048)
.stopSequences(List.of("this-is-the-end"))
@@ -89,7 +89,7 @@ public class BedrockProxyChatModelObservationIT {
@Test
void observationForStreamingChatOperation() {
var options = FunctionCallingOptions.builder()
var options = ToolCallingChatOptions.builder()
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
.maxTokens(2048)
.stopSequences(List.of("this-is-the-end"))
@@ -170,10 +170,10 @@ public class BedrockProxyChatModelObservationIT {
String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
return BedrockProxyChatModel.builder()
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
.withRegion(Region.US_EAST_1)
.withObservationRegistry(observationRegistry)
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.region(Region.US_EAST_1)
.observationRegistry(observationRegistry)
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
.build();
}

View File

@@ -33,6 +33,7 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
@@ -185,7 +186,7 @@ public class BedrockNovaChatClientIT {
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.region(Region.US_EAST_1)
.timeout(Duration.ofSeconds(120))
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
.build();
}

View File

@@ -43,8 +43,8 @@ public final class BedrockConverseChatModelMain {
var prompt = new Prompt("Tell me a joke?", ChatOptions.builder().model(modelId).build());
var chatModel = BedrockProxyChatModel.builder()
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
.withRegion(Region.US_EAST_1)
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.region(Region.US_EAST_1)
.build();
var chatResponse = chatModel.call(prompt);

View File

@@ -24,8 +24,8 @@ import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
import org.springframework.ai.bedrock.converse.MockWeatherService;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.function.FunctionToolCallback;
/**
* Used for reverse engineering the protocol
@@ -48,18 +48,17 @@ public final class BedrockConverseChatModelMain3 {
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
// temperature in Celsius.",
"What's the weather like in Paris? Return the temperature in Celsius.",
FunctionCallingOptions.builder()
ToolCallingChatOptions.builder()
.model(modelId)
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build());
BedrockProxyChatModel chatModel = BedrockProxyChatModel.builder()
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
.withRegion(Region.US_EAST_1)
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
.region(Region.US_EAST_1)
.build();
var response = chatModel.call(prompt);

View File

@@ -89,7 +89,7 @@ import org.springframework.util.MimeType;
* @author Alexandros Pappas
* @since 1.0.0
*/
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {
public class MistralAiChatModel implements ChatModel {
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
@@ -121,73 +121,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
/**
* @deprecated Use {@link MistralAiChatModel.Builder}.
*/
@Deprecated
public MistralAiChatModel(MistralAiApi mistralAiApi) {
this(mistralAiApi,
MistralAiChatOptions.builder()
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.build());
}
/**
* @deprecated Use {@link MistralAiChatModel.Builder}.
*/
@Deprecated
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* @deprecated Use {@link MistralAiChatModel.Builder}.
*/
@Deprecated
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
}
/**
* @deprecated Use {@link MistralAiChatModel.Builder}.
*/
@Deprecated
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
this(mistralAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
ObservationRegistry.NOOP);
}
/**
* @deprecated Use {@link MistralAiChatModel.Builder}.
*/
@Deprecated
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
this(mistralAiApi, options,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build(),
retryTemplate, observationRegistry);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions,
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, MistralAiChatOptions.builder().build(), List.of());
Assert.notNull(mistralAiApi, "mistralAiApi cannot be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
@@ -594,15 +530,11 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.model(MistralAiApi.ChatModel.SMALL.getValue())
.build();
private ToolCallingManager toolCallingManager;
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolFunctionCallbacks;
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@@ -625,18 +557,6 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
return this;
}
@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
@Deprecated
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}
public Builder retryTemplate(RetryTemplate retryTemplate) {
this.retryTemplate = retryTemplate;
return this;
@@ -649,25 +569,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
public MistralAiChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
return new MistralAiChatModel(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate,
observationRegistry);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
: List.of();
return new MistralAiChatModel(mistralAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,
retryTemplate, observationRegistry);
}
return new MistralAiChatModel(mistralAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
observationRegistry);
}

View File

@@ -265,12 +265,6 @@ public class MistralAiApi {
public enum ChatModel implements ChatModelDescription {
// @formatter:off
@Deprecated(forRemoval = true, since = "1.0.0-M6")
OPEN_MISTRAL_7B("open-mistral-7b"),
@Deprecated(forRemoval = true, since = "1.0.0-M6")
OPEN_MIXTRAL_7B("open-mixtral-8x7b"),
@Deprecated(forRemoval = true, since = "1.0.0-M6")
OPEN_MIXTRAL_22B("open-mixtral-8x22b"),
// Premier Models
CODESTRAL("codestral-latest"),
LARGE("mistral-large-latest"),

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,7 +32,6 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
@@ -70,7 +69,7 @@ public class MistralAiChatModelObservationIT {
@Test
void observationForChatOperation() {
var options = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.model(MistralAiApi.ChatModel.SMALL.getValue())
.maxTokens(2048)
.stop(List.of("this-is-the-end"))
.temperature(0.7)
@@ -91,7 +90,7 @@ public class MistralAiChatModelObservationIT {
@Test
void observationForStreamingChatOperation() {
var options = MistralAiChatOptions.builder()
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.model(MistralAiApi.ChatModel.SMALL.getValue())
.maxTokens(2048)
.stop(List.of("this-is-the-end"))
.temperature(0.7)
@@ -125,12 +124,12 @@ public class MistralAiChatModelObservationIT {
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
.that()
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.SMALL.getValue())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
AiOperationType.CHAT.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.MISTRAL_AI.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
MistralAiApi.ChatModel.SMALL.getValue())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
: KeyValue.NONE_VALUE)
@@ -181,9 +180,12 @@ public class MistralAiChatModelObservationIT {
@Bean
public MistralAiChatModel openAiChatModel(MistralAiApi mistralAiApi,
TestObservationRegistry observationRegistry) {
return new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder().build(),
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
observationRegistry);
return MistralAiChatModel.builder()
.mistralAiApi(mistralAiApi)
.defaultOptions(MistralAiChatOptions.builder().build())
.retryTemplate(RetryTemplate.defaultInstance())
.observationRegistry(observationRegistry)
.build();
}
}

View File

@@ -77,14 +77,16 @@ public class MistralAiRetryTests {
this.retryListener = new TestRetryListener();
this.retryTemplate.registerListener(this.retryListener);
this.chatModel = new MistralAiChatModel(this.mistralAiApi,
MistralAiChatOptions.builder()
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
.build(),
null, this.retryTemplate);
this.chatModel = MistralAiChatModel.builder()
.mistralAiApi(this.mistralAiApi)
.defaultOptions(MistralAiChatOptions.builder()
.temperature(0.7)
.topP(1.0)
.safePrompt(false)
.model(MistralAiApi.ChatModel.SMALL.getValue())
.build())
.retryTemplate(this.retryTemplate)
.build();
this.embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MetadataMode.EMBED,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
this.retryTemplate);

View File

@@ -43,8 +43,10 @@ public class MistralAiTestConfiguration {
@Bean
public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
return new MistralAiChatModel(mistralAiApi,
MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build());
return MistralAiChatModel.builder()
.mistralAiApi(mistralAiApi)
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL.getValue()).build())
.build();
}
}

View File

@@ -47,7 +47,7 @@ public class MistralAiApiIT {
void chatCompletionEntity() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false));
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();
@@ -64,7 +64,7 @@ public class MistralAiApiIT {
""", Role.SYSTEM);
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false));
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
assertThat(response).isNotNull();
assertThat(response.getBody()).isNotNull();
@@ -74,7 +74,7 @@ public class MistralAiApiIT {
void chatCompletionStream() {
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
Flux<ChatCompletionChunk> response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest(
List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, true));
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, true));
assertThat(response).isNotNull();
assertThat(response.collectList().block()).isNotNull();

View File

@@ -28,13 +28,6 @@ import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
@@ -45,7 +38,6 @@ import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -57,9 +49,9 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConven
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
@@ -70,6 +62,8 @@ import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
@@ -89,7 +83,7 @@ import org.springframework.util.StringUtils;
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
public class OllamaChatModel implements ChatModel {
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class);
@@ -125,24 +119,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
@Deprecated
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
ModelManagementOptions modelManagementOptions) {
this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks),
observationRegistry, modelManagementOptions);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, OllamaOptions.builder().build(), List.of());
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
@@ -381,10 +359,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
OllamaOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
OllamaOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OllamaOptions.class);
@@ -540,10 +514,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
private ToolCallingManager toolCallingManager;
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolFunctionCallbacks;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
@@ -566,18 +536,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
return this;
}
@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
@Deprecated
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}
public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
@@ -590,24 +548,9 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
public OllamaChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver must not be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks must not be set when toolCallingManager is set");
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
this.observationRegistry, this.modelManagementOptions);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager must not be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
: List.of();
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
toolCallbacks, this.observationRegistry, this.modelManagementOptions);
}
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
this.observationRegistry, this.modelManagementOptions);
}

View File

@@ -55,9 +55,11 @@ class OllamaChatModelTests {
@Test
void buildOllamaChatModelWithDeprecatedConstructor() {
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), null, null, ObservationRegistry.NOOP,
ModelManagementOptions.builder().build());
ChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(this.ollamaApi)
.defaultOptions(OllamaOptions.builder().model(OllamaModel.MISTRAL).build())
.observationRegistry(ObservationRegistry.NOOP)
.build();
assertThat(chatModel).isNotNull();
}

View File

@@ -29,12 +29,6 @@ import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.lang.Nullable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
@@ -50,7 +44,6 @@ import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
@@ -67,6 +60,9 @@ import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
@@ -79,6 +75,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
@@ -111,7 +108,7 @@ import org.springframework.util.StringUtils;
* @see StreamingChatModel
* @see OpenAiApi
*/
public class OpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
public class OpenAiChatModel implements ChatModel {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
@@ -146,96 +143,8 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
/**
* Creates an instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @throws IllegalArgumentException if openAiApi is null
* @deprecated Use OpenAiChatModel.Builder.
*/
@Deprecated
public OpenAiChatModel(OpenAiApi openAiApi) {
this(openAiApi, OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build());
}
/**
* Initializes an instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat model.
* @deprecated Use OpenAiChatModel.Builder.
*/
@Deprecated
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat model.
* @param functionCallbackResolver The function callback resolver.
* @param retryTemplate The retry template.
* @deprecated Use OpenAiChatModel.Builder.
*/
@Deprecated
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
this(openAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
}
/**
* Initializes a new instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat model.
* @param functionCallbackResolver The function callback resolver.
* @param toolFunctionCallbacks The tool function callbacks.
* @param retryTemplate The retry template.
* @deprecated Use OpenAiChatModel.Builder.
*/
@Deprecated
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
ObservationRegistry.NOOP);
}
/**
* Initializes a new instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat model.
* @param functionCallbackResolver The function callback resolver.
* @param toolFunctionCallbacks The tool function callbacks.
* @param retryTemplate The retry template.
* @param observationRegistry The ObservationRegistry used for instrumentation.
* @deprecated Use OpenAiChatModel.Builder or OpenAiChatModel(OpenAiApi,
* OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry).
*/
@Deprecated
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
@Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
this(openAiApi, options,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build(),
retryTemplate, observationRegistry);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the OpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, OpenAiChatOptions.builder().build(), List.of());
Assert.notNull(openAiApi, "openAiApi cannot be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
@@ -777,10 +686,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
private ToolCallingManager toolCallingManager;
private FunctionCallbackResolver functionCallbackResolver;
private List<FunctionCallback> toolFunctionCallbacks;
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
@@ -803,18 +708,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
return this;
}
@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}
@Deprecated
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}
public Builder retryTemplate(RetryTemplate retryTemplate) {
this.retryTemplate = retryTemplate;
return this;
@@ -827,25 +720,9 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
public OpenAiChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate,
observationRegistry);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
: List.of();
return new OpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,
retryTemplate, observationRegistry);
}
return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
observationRegistry);
}

View File

@@ -85,97 +85,6 @@ public class OpenAiApi {
private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
/**
* Create a new chat completion api with base URL set to https://api.openai.com
* @param apiKey OpenAI apiKey.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String apiKey) {
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String baseUrl, String apiKey) {
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder) {
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
responseErrorHandler);
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param completionsPath the path to the chat completions endpoint.
* @param embeddingsPath the path to the embeddings endpoint.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), completionsPath, embeddingsPath,
restClientBuilder, webClientBuilder, responseErrorHandler);
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param completionsPath the path to the chat completions endpoint.
* @param embeddingsPath the path to the embeddings endpoint.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
*/
@Deprecated(since = "1.0.0.M6")
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder,
webClientBuilder, responseErrorHandler);
}
/**
* Create a new chat completion api.
* @param baseUrl api base URL.

View File

@@ -57,79 +57,6 @@ public class OpenAiAudioApi {
private final WebClient webClient;
/**
* Create a new audio api.
* @param openAiToken OpenAI apiKey.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiAudioApi(String openAiToken) {
this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), WebClient.builder(),
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
* Create a new audio api.
* @param baseUrl api base URL.
* @param openAiToken OpenAI apiKey.
* @param restClientBuilder RestClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {
Consumer<HttpHeaders> authHeaders;
if (openAiToken != null && !openAiToken.isEmpty()) {
authHeaders = h -> h.setBearerAuth(openAiToken);
}
else {
authHeaders = h -> {
};
}
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(authHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
}
/**
* Create a new audio api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiAudioApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, webClientBuilder,
responseErrorHandler);
}
/**
* Create a new audio api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param restClientBuilder RestClient builder.
* @param webClientBuilder WebClient builder.
* @param responseErrorHandler Response error handler.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiAudioApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, webClientBuilder, responseErrorHandler);
}
/**
* Create a new audio api.
* @param baseUrl api base URL.
@@ -496,71 +423,26 @@ public class OpenAiAudioApi {
private Float speed;
/**
* @deprecated use {@link #model(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withModel(String model) {
this.model = model;
return this;
}
public Builder model(String model) {
this.model = model;
return this;
}
/**
* @deprecated use {@link #input(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withInput(String input) {
this.input = input;
return this;
}
public Builder input(String input) {
this.input = input;
return this;
}
/**
* @deprecated use {@link #voice(Voice)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withVoice(Voice voice) {
this.voice = voice;
return this;
}
public Builder voice(Voice voice) {
this.voice = voice;
return this;
}
/**
* @deprecated use {@link #responseFormat(AudioResponseFormat)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withResponseFormat(AudioResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
public Builder responseFormat(AudioResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
/**
* @deprecated use {@link #speed(Float)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withSpeed(Float speed) {
this.speed = speed;
return this;
}
public Builder speed(Float speed) {
this.speed = speed;
return this;
@@ -653,99 +535,36 @@ public class OpenAiAudioApi {
private GranularityType granularityType;
/**
* @deprecated use {@link #file(byte[])} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withFile(byte[] file) {
this.file = file;
return this;
}
public Builder file(byte[] file) {
this.file = file;
return this;
}
/**
* @deprecated use {@link #model(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withModel(String model) {
this.model = model;
return this;
}
public Builder model(String model) {
this.model = model;
return this;
}
/**
* @deprecated use {@link #language(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withLanguage(String language) {
this.language = language;
return this;
}
public Builder language(String language) {
this.language = language;
return this;
}
/**
* @deprecated use {@link #prompt(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withPrompt(String prompt) {
this.prompt = prompt;
return this;
}
public Builder prompt(String prompt) {
this.prompt = prompt;
return this;
}
/**
* @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withResponseFormat(TranscriptResponseFormat response_format) {
this.responseFormat = response_format;
return this;
}
public Builder responseFormat(TranscriptResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
/**
* @deprecated use {@link #temperature(Float)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withTemperature(Float temperature) {
this.temperature = temperature;
return this;
}
public Builder temperature(Float temperature) {
this.temperature = temperature;
return this;
}
/**
* @deprecated use {@link #granularityType(GranularityType)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withGranularityType(GranularityType granularityType) {
this.granularityType = granularityType;
return this;
}
public Builder granularityType(GranularityType granularityType) {
this.granularityType = granularityType;
return this;
@@ -805,71 +624,26 @@ public class OpenAiAudioApi {
private Float temperature;
/**
* @deprecated use {@link #file(byte[])} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withFile(byte[] file) {
this.file = file;
return this;
}
public Builder file(byte[] file) {
this.file = file;
return this;
}
/**
* @deprecated use {@link #model(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withModel(String model) {
this.model = model;
return this;
}
public Builder model(String model) {
this.model = model;
return this;
}
/**
* @deprecated use {@link #prompt(String)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withPrompt(String prompt) {
this.prompt = prompt;
return this;
}
public Builder prompt(String prompt) {
this.prompt = prompt;
return this;
}
/**
* @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withResponseFormat(TranscriptResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
public Builder responseFormat(TranscriptResponseFormat responseFormat) {
this.responseFormat = responseFormat;
return this;
}
/**
* @deprecated use {@link #temperature(Float)} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public Builder withTemperature(Float temperature) {
this.temperature = temperature;
return this;
}
public Builder temperature(Float temperature) {
this.temperature = temperature;
return this;

View File

@@ -17,7 +17,6 @@
package org.springframework.ai.openai.api;
import java.util.List;
import java.util.Map;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -30,7 +29,6 @@ import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.ResponseErrorHandler;
@@ -47,56 +45,6 @@ public class OpenAiImageApi {
private final RestClient restClient;
/**
* Create a new OpenAI Image api with base URL set to {@code https://api.openai.com}.
* @param openAiToken OpenAI apiKey.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiImageApi(String openAiToken) {
this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder());
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param openAiToken OpenAI apiKey.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param apiKey OpenAI apiKey.
* @param restClientBuilder the rest client builder to use.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param restClientBuilder the rest client builder to use.
* @param responseErrorHandler the response error handler to use.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiImageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
}
/**
* Create a new OpenAI Image API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.

View File

@@ -52,33 +52,6 @@ public class OpenAiModerationApi {
private final ObjectMapper objectMapper;
/**
* Create a new OpenAI Moderation api with base URL set to https://api.openai.com
* @param openAiToken OpenAI apiKey.
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiModerationApi(String openAiToken) {
this(DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
* @deprecated use {@link Builder} instead.
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
public OpenAiModerationApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
if (openAiToken != null && !openAiToken.isEmpty()) {
h.setBearerAuth(openAiToken);
}
h.setContentType(MediaType.APPLICATION_JSON);
}).defaultStatusHandler(responseErrorHandler).build();
}
/**
* Create a new OpenAI Moderation API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.

View File

@@ -74,8 +74,10 @@ class ChatCompletionRequestTests {
@Test
void createRequestWithChatOptions() {
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());
var client = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
.defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build())
.build();
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
@@ -101,8 +103,10 @@ class ChatCompletionRequestTests {
void promptOptionsTools() {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder().model("DEFAULT_MODEL").build());
var client = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
.defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").build())
.build();
var prompt = client.buildRequestPrompt(new Prompt("Test message content",
OpenAiChatOptions.builder()
@@ -128,15 +132,17 @@ class ChatCompletionRequestTests {
void defaultOptionsTools() {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder()
.model("DEFAULT_MODEL")
.functionCallbacks(List.of(FunctionCallback.builder()
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build());
var client = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
.defaultOptions(OpenAiChatOptions.builder()
.model("DEFAULT_MODEL")
.functionCallbacks(List.of(FunctionCallback.builder()
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build())
.build();
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));

View File

@@ -61,28 +61,25 @@ public class OpenAiTestConfiguration {
@Bean
public OpenAiChatModel openAiChatModel(OpenAiApi api) {
OpenAiChatModel openAiChatModel = new OpenAiChatModel(api,
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build());
return openAiChatModel;
return OpenAiChatModel.builder()
.openAiApi(api)
.defaultOptions(OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build())
.build();
}
@Bean
public OpenAiAudioTranscriptionModel openAiTranscriptionModel(OpenAiAudioApi api) {
OpenAiAudioTranscriptionModel openAiTranscriptionModel = new OpenAiAudioTranscriptionModel(api);
return openAiTranscriptionModel;
return new OpenAiAudioTranscriptionModel(api);
}
@Bean
public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
OpenAiAudioSpeechModel openAiAudioSpeechModel = new OpenAiAudioSpeechModel(api);
return openAiAudioSpeechModel;
return new OpenAiAudioSpeechModel(api);
}
@Bean
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi);
// openAiImageModel.setModel("foobar");
return openAiImageModel;
return new OpenAiImageModel(imageApi);
}
@Bean
@@ -92,8 +89,7 @@ public class OpenAiTestConfiguration {
@Bean
public OpenAiModerationModel openAiModerationClient(OpenAiModerationApi openAiModerationApi) {
OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(openAiModerationApi);
return openAiModerationModel;
return new OpenAiModerationModel(openAiModerationApi);
}
}

View File

@@ -46,7 +46,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiIT {
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
OpenAiApi openAiApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
@Test
void chatCompletionEntity() {

View File

@@ -52,7 +52,7 @@ public class OpenAiApiToolFunctionCallIT {
MockWeatherService weatherService = new MockWeatherService();
OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
OpenAiApi completionApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
private static <T> T fromJson(String json, Class<T> targetClass) {
try {

View File

@@ -74,7 +74,7 @@ public class MessageTypeContentTests {
@BeforeEach
public void beforeEach() {
this.chatModel = new OpenAiChatModel(this.openAiApi);
this.chatModel = OpenAiChatModel.builder().openAiApi(this.openAiApi).build();
}
@Test

View File

@@ -73,7 +73,7 @@ public class OpenAiChatModelAdditionalHttpHeadersIT {
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
}

View File

@@ -219,12 +219,12 @@ class OpenAiChatModelFunctionCallingIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
}

View File

@@ -62,7 +62,7 @@ public class OpenAiChatModelNoOpApiKeysIT {
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
}

View File

@@ -31,6 +31,8 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.openai.OpenAiChatModel;
@@ -169,14 +171,13 @@ public class OpenAiChatModelObservationIT {
@Bean
public OpenAiApi openAiApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi, TestObservationRegistry observationRegistry) {
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(),
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
observationRegistry);
ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry);
}
}

View File

@@ -1,372 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.chat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingHelper;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.util.CollectionUtils;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiChatModelProxyToolCallsIT {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelProxyToolCallsIT.class);
private static final String DEFAULT_MODEL = "gpt-4o-mini";
FunctionCallback functionDefinition = new FunctionCallingHelper.FunctionDefinition("getWeatherInLocation",
"Get the weather in location", """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
}
},
"required": ["location", "unit"]
}
""");
@Autowired
private OpenAiChatModel chatModel;
// Helper class that reuses some of the {@link AbstractToolCallSupport} functionality
// to help to implement the function call handling logic on the client side.
private FunctionCallingHelper functionCallingHelper = new FunctionCallingHelper();
@SuppressWarnings("unchecked")
private static Map<String, String> getFunctionArguments(String functionArguments) {
try {
return new ObjectMapper().readValue(functionArguments, Map.class);
}
catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
// Function which will be called by the AI model.
private String getWeatherInLocation(String location, String unit) {
double temperature = 0;
if (location.contains("Paris")) {
temperature = 15;
}
else if (location.contains("Tokyo")) {
temperature = 10;
}
else if (location.contains("San Francisco")) {
temperature = 30;
}
return String.format("The weather in %s is %s%s", location, temperature, unit);
}
@Test
void functionCall() throws JsonMappingException, JsonProcessingException {
List<Message> messages = List
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
var prompt = new Prompt(messages, promptOptions);
boolean isToolCall = false;
ChatResponse chatResponse = null;
do {
chatResponse = this.chatModel.call(prompt);
// We will have to convert the chatResponse into OpenAI assistant message.
// Note that the tool call check could be platform specific because the finish
// reasons.
isToolCall = this.functionCallingHelper.isToolCall(chatResponse,
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()));
if (isToolCall) {
Optional<Generation> toolCallGeneration = chatResponse.getResults()
.stream()
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
.findFirst();
assertThat(toolCallGeneration).isNotEmpty();
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
var functionName = toolCall.name();
assertThat(functionName).isEqualTo("getWeatherInLocation");
String functionArguments = toolCall.arguments();
@SuppressWarnings("unchecked")
Map<String, String> argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class);
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
argumentsMap.get("unit").toString());
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName,
ModelOptionsUtils.toJsonString(functionResponse)));
}
ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());
List<Message> toolCallConversation = this.functionCallingHelper
.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);
assertThat(toolCallConversation).isNotEmpty();
prompt = new Prompt(toolCallConversation, prompt.getOptions());
}
}
while (isToolCall);
logger.info("Response: {}", chatResponse);
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
}
@Test
void functionStream() throws JsonMappingException, JsonProcessingException {
List<Message> messages = List
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
var prompt = new Prompt(messages, promptOptions);
String response = processToolCall(prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> {
var functionName = toolCall.name();
assertThat(functionName).isEqualTo("getWeatherInLocation");
String functionArguments = toolCall.arguments();
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
argumentsMap.get("unit").toString());
return functionResponse;
})
.collectList()
.block()
.stream()
.map(cr -> cr.getResult().getOutput().getText())
.collect(Collectors.joining());
logger.info("Response: {}", response);
assertThat(response).contains("30", "10", "15");
}
private Flux<ChatResponse> processToolCall(Prompt prompt, Set<String> finishReasons,
Function<AssistantMessage.ToolCall, String> customFunction) {
Flux<ChatResponse> chatResponses = this.chatModel.stream(prompt);
return chatResponses.flatMap(chatResponse -> {
boolean isToolCall = this.functionCallingHelper.isToolCall(chatResponse, finishReasons);
if (isToolCall) {
Optional<Generation> toolCallGeneration = chatResponse.getResults()
.stream()
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
.findFirst();
assertThat(toolCallGeneration).isNotEmpty();
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
String functionResponse = customFunction.apply(toolCall);
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(),
ModelOptionsUtils.toJsonString(functionResponse)));
}
ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());
List<Message> toolCallConversation = this.functionCallingHelper
.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);
assertThat(toolCallConversation).isNotEmpty();
var prompt2 = new Prompt(toolCallConversation, prompt.getOptions());
return processToolCall(prompt2, finishReasons, customFunction);
}
return Flux.just(chatResponse);
});
}
@Test
void functionCall2() throws JsonMappingException, JsonProcessingException {
List<Message> messages = List
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
var prompt = new Prompt(messages, promptOptions);
ChatResponse chatResponse = this.functionCallingHelper.processCall(this.chatModel, prompt,
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()),
toolCall -> {
var functionName = toolCall.name();
assertThat(functionName).isEqualTo("getWeatherInLocation");
String functionArguments = toolCall.arguments();
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
argumentsMap.get("unit").toString());
return functionResponse;
});
logger.info("Response: {}", chatResponse);
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
}
@Test
void functionStream2() throws JsonMappingException, JsonProcessingException {
List<Message> messages = List
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
var prompt = new Prompt(messages, promptOptions);
Flux<ChatResponse> responses = this.functionCallingHelper.processStream(this.chatModel, prompt,
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()),
toolCall -> {
var functionName = toolCall.name();
assertThat(functionName).isEqualTo("getWeatherInLocation");
String functionArguments = toolCall.arguments();
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
argumentsMap.get("unit").toString());
return functionResponse;
});
String response = responses.collectList()
.block()
.stream()
.map(cr -> cr.getResult().getOutput().getText())
.collect(Collectors.joining());
logger.info("Response: {}", response);
assertThat(response).contains("30", "10", "15");
}
@SpringBootConfiguration
static class Config {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List<FunctionCallback> toolFunctionCallbacks) {
// enable the proxy tool calls option.
var options = OpenAiChatOptions.builder().model(DEFAULT_MODEL).proxyToolCalls(true).build();
return new OpenAiChatModel(openAiApi, options, null, toolFunctionCallbacks,
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
}
}
}

View File

@@ -234,12 +234,12 @@ public class OpenAiChatModelResponseFormatIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
}

View File

@@ -18,6 +18,7 @@ package org.springframework.ai.openai.chat;
import java.time.Duration;
import org.hamcrest.core.StringContains;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
@@ -133,7 +134,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests {
httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358");
httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms");
this.server.expect(requestTo("/v1/chat/completions"))
this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions")))
.andExpect(method(HttpMethod.POST))
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));
@@ -169,12 +170,16 @@ public class OpenAiChatModelWithChatResponseMetadataTests {
@Bean
public OpenAiApi chatCompletionApi(RestClient.Builder builder, WebClient.Builder webClientBuilder) {
return new OpenAiApi("", TEST_API_KEY, builder, webClientBuilder);
return OpenAiApi.builder()
.apiKey(TEST_API_KEY)
.restClientBuilder(builder)
.webClientBuilder(webClientBuilder)
.build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
}

View File

@@ -55,7 +55,10 @@ public class OpenAiCompatibleChatModelIT {
static Stream<ChatModel> openAiCompatibleApis() {
Stream.Builder<ChatModel> builder = Stream.builder();
builder.add(new OpenAiChatModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")), forModelName("gpt-3.5-turbo")));
builder.add(OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build())
.defaultOptions(forModelName("gpt-3.5-turbo"))
.build());
// (26.01.2025) Disable because the Groq API is down. TODO: Re-enable when the API
// is back up.
@@ -66,9 +69,13 @@ public class OpenAiCompatibleChatModelIT {
// }
if (System.getenv("OPEN_ROUTER_API_KEY") != null) {
builder.add(new OpenAiChatModel(
new OpenAiApi("https://openrouter.ai/api", System.getenv("OPEN_ROUTER_API_KEY")),
forModelName("meta-llama/llama-3-8b-instruct")));
builder.add(OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl("https://openrouter.ai/api")
.apiKey(System.getenv("OPEN_ROUTER_API_KEY"))
.build())
.defaultOptions(forModelName("meta-llama/llama-3-8b-instruct"))
.build());
}
return builder.build();

View File

@@ -16,11 +16,13 @@
package org.springframework.ai.openai.chat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@@ -34,19 +36,28 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Description;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.ParameterizedTypeReference;
import static org.assertj.core.api.Assertions.assertThat;
@@ -71,7 +82,7 @@ public class OpenAiPaymentTransactionIT {
public void transactionPaymentStatuses(String functionName) {
List<TransactionStatusResponse> content = this.chatClient.prompt()
.advisors(new LoggingAdvisor())
.functions(functionName)
.tools(functionName)
.user("""
What is the status of my payment transactions 001, 002 and 003?
""")
@@ -102,7 +113,7 @@ public class OpenAiPaymentTransactionIT {
Flux<String> flux = this.chatClient.prompt()
.advisors(new LoggingAdvisor())
.functions(functionName)
.tools(functionName)
.user(u -> u.text("""
What is the status of my payment transactions 001, 002 and 003?
@@ -217,25 +228,56 @@ public class OpenAiPaymentTransactionIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, FunctionCallbackResolver functionCallbackResolver) {
return new OpenAiChatModel(openAiApi,
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build(),
functionCallbackResolver, RetryUtils.DEFAULT_RETRY_TEMPLATE);
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, ToolCallingManager toolCallingManager) {
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.toolCallingManager(toolCallingManager)
.defaultOptions(
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build())
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
.build();
}
/**
* Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance must
* different from the other JSON schema types.
*/
@Bean
public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
manager.setApplicationContext(context);
return manager;
@ConditionalOnMissingBean
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
List<FunctionCallback> functionCallbacks, List<ToolCallbackProvider> tcbProviders) {
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks);
tcbProviders.stream()
.map(pr -> List.of(pr.getToolCallbacks()))
.forEach(allFunctionAndToolCallbacks::addAll);
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder()
.applicationContext(applicationContext)
.build();
return new DelegatingToolCallbackResolver(
List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
}
@Bean
@ConditionalOnMissingBean
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() {
return new DefaultToolExecutionExceptionProcessor(false);
}
@Bean
@ConditionalOnMissingBean
ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor,
ObjectProvider<ObservationRegistry> observationRegistry) {
return ToolCallingManager.builder()
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
.toolCallbackResolver(toolCallbackResolver)
.toolExecutionExceptionProcessor(toolExecutionExceptionProcessor)
.build();
}
}

View File

@@ -110,8 +110,11 @@ public class OpenAiRetryTests {
this.retryListener = new TestRetryListener();
this.retryTemplate.registerListener(this.retryListener);
this.chatModel = new OpenAiChatModel(this.openAiApi, OpenAiChatOptions.builder().build(), null,
this.retryTemplate);
this.chatModel = OpenAiChatModel.builder()
.openAiApi(this.openAiApi)
.defaultOptions(OpenAiChatOptions.builder().build())
.retryTemplate(this.retryTemplate)
.build();
this.embeddingModel = new OpenAiEmbeddingModel(this.openAiApi, MetadataMode.EMBED,
OpenAiEmbeddingOptions.builder().build(), this.retryTemplate);
this.audioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi,

View File

@@ -330,12 +330,15 @@ class DeepSeekWithOpenAiChatModelIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(DEEPSEEK_BASE_URL, System.getenv("DEEPSEEK_API_KEY"));
return OpenAiApi.builder().baseUrl(DEEPSEEK_BASE_URL).apiKey(System.getenv("DEEPSEEK_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build());
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build())
.build();
}
}

View File

@@ -383,12 +383,15 @@ class GroqWithOpenAiChatModelIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(GROQ_BASE_URL, System.getenv("GROQ_API_KEY"));
return OpenAiApi.builder().baseUrl(GROQ_BASE_URL).apiKey(System.getenv("GROQ_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build());
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build())
.build();
}
}

View File

@@ -318,13 +318,15 @@ class NvidiaWithOpenAiChatModelIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(NVIDIA_BASE_URL, System.getenv("NVIDIA_API_KEY"));
return OpenAiApi.builder().baseUrl(NVIDIA_BASE_URL).apiKey(System.getenv("NVIDIA_API_KEY")).build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi,
OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build());
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build())
.build();
}
}

View File

@@ -327,14 +327,20 @@ class PerplexityWithOpenAiChatModelIT {
@Bean
public OpenAiApi chatCompletionApi() {
return new OpenAiApi(PERPLEXITY_BASE_URL, System.getenv("PERPLEXITY_API_KEY"), PERPLEXITY_COMPLETIONS_PATH,
"/v1/embeddings", RestClient.builder(), WebClient.builder(),
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
return OpenAiApi.builder()
.baseUrl(PERPLEXITY_BASE_URL)
.apiKey(System.getenv("PERPLEXITY_API_KEY"))
.completionsPath(PERPLEXITY_COMPLETIONS_PATH)
.embeddingsPath("/v1/embeddings")
.build();
}
@Bean
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build());
return OpenAiChatModel.builder()
.openAiApi(openAiApi)
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build())
.build();
}
}

View File

@@ -104,7 +104,7 @@ public class OpenAiEmbeddingModelObservationIT {
@Bean
public OpenAiApi openAiApi() {
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
}
@Bean

View File

@@ -163,13 +163,12 @@ public class MetadataTransformerIT {
throw new IllegalArgumentException(
"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
}
return new OpenAiApi(apiKey);
return OpenAiApi.builder().apiKey(apiKey).build();
}
@Bean
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
OpenAiChatModel openAiChatModel = new OpenAiChatModel(openAiApi);
return openAiChatModel;
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
}
@Bean

View File

@@ -1,235 +0,0 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vertexai.gemini.tool;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Description;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
*/
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
@Deprecated
public class VertexAiGeminiPaymentTransactionDeprecatedIT {
private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionDeprecatedIT.class);
private static final Map<Transaction, Status> DATASET = Map.of(new Transaction("001"), new Status("pending"),
new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected"));
@Autowired
ChatClient chatClient;
@Test
public void paymentStatuses() {
// @formatter:off
String content = this.chatClient.prompt()
.advisors(new LoggingAdvisor())
.tools("paymentStatus")
.user("""
What is the status of my payment transactions 001, 002 and 003?
If requred invoke the function per transaction.
""").call().content();
logger.info("" + content);
assertThat(content).contains("001", "002", "003");
assertThat(content).contains("pending", "approved", "rejected");
}
@RepeatedTest(5)
public void streamingPaymentStatuses() {
Flux<String> streamContent = this.chatClient.prompt()
.advisors(new LoggingAdvisor())
.tools("paymentStatus")
.user("""
What is the status of my payment transactions 001, 002 and 003?
If requred invoke the function per transaction.
""")
.stream()
.content();
String content = streamContent.collectList().block().stream().collect(Collectors.joining());
logger.info(content);
assertThat(content).contains("001", "002", "003");
assertThat(content).contains("pending", "approved", "rejected");
// Quota rate
try {
Thread.sleep(1000);
}
catch (InterruptedException e) {
}
}
record TransactionStatusResponse(String id, String status) {
}
private static class LoggingAdvisor implements CallAroundAdvisor {
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
@Override
public String getName() {
return this.getClass().getSimpleName();
}
@Override
public int getOrder() {
return 0;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
var response = chain.nextAroundCall(before(advisedRequest));
observeAfter(response);
return response;
}
private AdvisedRequest before(AdvisedRequest request) {
logger.info("System text: \n" + request.systemText());
logger.info("System params: " + request.systemParams());
logger.info("User text: \n" + request.userText());
logger.info("User params:" + request.userParams());
logger.info("Function names: " + request.functionNames());
logger.info("Options: " + request.chatOptions().toString());
return request;
}
private void observeAfter(AdvisedResponse advisedResponse) {
logger.info("Response: " + advisedResponse.response());
}
}
record Transaction(String id) {
}
record Status(String name) {
}
record Transactions(List<Transaction> transactions) {
}
record Statuses(List<Status> statuses) {
}
@SpringBootConfiguration
public static class TestConfiguration {
@Bean
@Description("Get the status of a single payment transaction")
public Function<Transaction, Status> paymentStatus() {
return transaction -> {
logger.info("Single Transaction: " + transaction);
return DATASET.get(transaction);
};
}
@Bean
@Description("Get the list statuses of a list of payment transactions")
public Function<Transactions, Statuses> paymentStatuses() {
return transactions -> {
logger.info("Transactions: " + transactions);
return new Statuses(transactions.transactions().stream().map(t -> DATASET.get(t)).toList());
};
}
@Bean
public ChatClient chatClient(VertexAiGeminiChatModel chatModel) {
return ChatClient.builder(chatModel).build();
}
@Bean
public VertexAI vertexAiApi() {
String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
return new VertexAI.Builder().setLocation(location)
.setProjectId(projectId)
.setTransport(Transport.REST)
// .setTransport(Transport.GRPC)
.build();
}
@Bean
public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ApplicationContext context) {
FunctionCallbackResolver functionCallbackResolver = springAiFunctionManager(context);
return new VertexAiGeminiChatModel(vertexAi,
VertexAiGeminiChatOptions.builder()
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
.temperature(0.1)
.build(),
functionCallbackResolver);
}
/**
* Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance
* must
* different from the other JSON schema types.
*/
private FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
manager.setSchemaType(SchemaType.OPEN_API_SCHEMA);
manager.setApplicationContext(context);
return manager;
}
}
}

View File

@@ -30,19 +30,13 @@ import com.fasterxml.jackson.annotation.JsonPropertyOrder;
* @author Ilayaperumal Gopinathan
* @since 1.0.0
*/
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "generationTokens", "nativeUsage" })
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "nativeUsage" })
public class DefaultUsage implements Usage {
private final Integer promptTokens;
private final Integer completionTokens;
/**
* @deprecated as of 1.0.0-M6, scheduled for removal
*/
@Deprecated(forRemoval = true, since = "1.0.0-M6")
private final Long generationTokens;
private final int totalTokens;
private final Object nativeUsage;
@@ -62,7 +56,6 @@ public class DefaultUsage implements Usage {
public DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens, Object nativeUsage) {
this.promptTokens = promptTokens != null ? promptTokens : 0;
this.completionTokens = completionTokens != null ? completionTokens : 0;
this.generationTokens = Long.valueOf(this.completionTokens);
this.totalTokens = totalTokens != null ? totalTokens
: calculateTotalTokens(this.promptTokens, this.completionTokens);
this.nativeUsage = nativeUsage;
@@ -106,11 +99,8 @@ public class DefaultUsage implements Usage {
@JsonCreator
public static DefaultUsage fromJson(@JsonProperty("promptTokens") Integer promptTokens,
@JsonProperty("completionTokens") Integer completionTokens,
@JsonProperty("generationTokens") Long generationTokens, @JsonProperty("totalTokens") Integer totalTokens,
@JsonProperty("nativeUsage") Object nativeUsage) {
Integer effectiveCompletionTokens = completionTokens != null ? completionTokens
: (generationTokens != null ? generationTokens.intValue() : 0);
return new DefaultUsage(promptTokens, effectiveCompletionTokens, totalTokens, nativeUsage);
@JsonProperty("totalTokens") Integer totalTokens, @JsonProperty("nativeUsage") Object nativeUsage) {
return new DefaultUsage(promptTokens, completionTokens, totalTokens, nativeUsage);
}
@Override

View File

@@ -34,11 +34,6 @@ public interface Usage {
*/
Integer getPromptTokens();
@Deprecated(forRemoval = true, since = "1.0.0-M6")
default Long getGenerationTokens() {
return getCompletionTokens().longValue();
}
/**
* Returns the number of tokens returned in the {@literal generation (aka completion)}
* of the AI's response.

View File

@@ -32,13 +32,12 @@ public class DefaultUsageTests {
void testSerializationWithAllFields() throws Exception {
DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json)
.isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}");
assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}");
}
@Test
void testDeserializationWithAllFields() throws Exception {
String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}";
String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}";
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
assertThat(usage.getPromptTokens()).isEqualTo(100);
assertThat(usage.getCompletionTokens()).isEqualTo(50);
@@ -49,8 +48,7 @@ public class DefaultUsageTests {
void testSerializationWithNullFields() throws Exception {
DefaultUsage usage = new DefaultUsage((Integer) null, (Integer) null, (Integer) null);
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json)
.isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}");
assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}");
}
@Test
@@ -92,8 +90,7 @@ public class DefaultUsageTests {
// Test serialization
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json)
.isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}");
assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}");
// Test deserialization
DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class);
@@ -113,8 +110,7 @@ public class DefaultUsageTests {
// Test serialization
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json)
.isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}");
assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}");
// Test deserialization
DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class);
@@ -123,18 +119,9 @@ public class DefaultUsageTests {
assertThat(deserializedUsage.getTotalTokens()).isEqualTo(0);
}
@Test
void testDeserializationWithLegacyFormat() throws Exception {
String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}";
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
assertThat(usage.getPromptTokens()).isEqualTo(100);
assertThat(usage.getCompletionTokens()).isEqualTo(50);
assertThat(usage.getTotalTokens()).isEqualTo(150);
}
@Test
void testDeserializationWithDifferentPropertyOrder() throws Exception {
String json = "{\"totalTokens\":150,\"generationTokens\":50,\"completionTokens\":50,\"promptTokens\":100}";
String json = "{\"totalTokens\":150,\"completionTokens\":50,\"promptTokens\":100}";
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
assertThat(usage.getPromptTokens()).isEqualTo(100);
assertThat(usage.getCompletionTokens()).isEqualTo(50);
@@ -150,7 +137,7 @@ public class DefaultUsageTests {
DefaultUsage usage = new DefaultUsage(100, 50, 150, customNativeUsage);
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json).isEqualTo(
"{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}");
"{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}");
}
@Test
@@ -195,14 +182,6 @@ public class DefaultUsageTests {
assertThat(deserializedMap.get("field5")).isEqualTo(java.util.Map.of("nested", "value"));
}
@Test
@SuppressWarnings("deprecation")
void testDeprecatedGenerationTokens() {
DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
assertThat(usage.getGenerationTokens()).isEqualTo(50L);
assertThat(usage.getCompletionTokens().longValue()).isEqualTo(usage.getGenerationTokens());
}
@Test
void testEqualsAndHashCode() {
DefaultUsage usage1 = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
@@ -262,8 +241,7 @@ public class DefaultUsageTests {
assertThat(usage.getTotalTokens()).isEqualTo(-3);
String json = this.objectMapper.writeValueAsString(usage);
assertThat(json)
.isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3,\"generationTokens\":-2}");
assertThat(json).isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3}");
}
@Test

View File

@@ -73,18 +73,6 @@ public class ToolCallingManagerTests {
runExplicitToolCallingExecutionWithOptions(chatOptions, prompt);
}
@Test
void explicitToolCallingExecutionWithLegacyOptions() {
ChatOptions chatOptions = FunctionCallingOptions.builder()
.functionCallbacks(ToolCallbacks.from(tools))
.proxyToolCalls(true)
.build();
Prompt prompt = new Prompt(
new UserMessage("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")),
chatOptions);
runExplicitToolCallingExecutionWithOptions(chatOptions, prompt);
}
@Test
void explicitToolCallingExecutionWithNewOptionsStream() {
ChatOptions chatOptions = ToolCallingChatOptions.builder()

View File

@@ -35,7 +35,7 @@ public class MistralAiChatProperties extends MistralAiParentProperties {
public static final String CONFIG_PREFIX = "spring.ai.mistralai.chat";
public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue();
public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.SMALL.getValue();
private static final Double DEFAULT_TEMPERATURE = 0.7;

View File

@@ -35,6 +35,7 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -97,11 +98,9 @@ class FunctionCallbackWithPlainFunctionBeanIT {
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
.function("weatherFunction")
.build();
ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build();
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions));
logger.info("Response: {}", response);
});

View File

@@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.moonshot.MoonshotChatModel;
import org.springframework.ai.moonshot.MoonshotChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -98,11 +99,9 @@ class FunctionCallbackWithPlainFunctionBeanIT {
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius");
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
.function("weatherFunction")
.build();
ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build();
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions));
logger.info("Response: {}", response);
});

View File

@@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.vertexai.gemini.tool;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import org.junit.jupiter.api.Test;
@@ -29,6 +30,7 @@ import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -109,14 +111,14 @@ class FunctionCallWithFunctionBeanIT {
""");
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
FunctionCallingOptions.builder().function("weatherFunction").build()));
ToolCallingChatOptions.builder().toolNames("weatherFunction").build()));
logger.info("Response: {}", response);
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
response = chatModel.call(new Prompt(List.of(userMessage),
VertexAiGeminiChatOptions.builder().function("weatherFunction3").build()));
VertexAiGeminiChatOptions.builder().toolNames(Set.of("weatherFunction3")).build()));
logger.info("Response: {}", response);

View File

@@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
@@ -97,8 +98,8 @@ class FunctionCallbackWithPlainFunctionBeanIT {
UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
.function("weatherFunction")
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder()
.toolNames("weatherFunction")
.build();
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));

View File

@@ -100,7 +100,7 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -121,7 +121,7 @@ public class BasicAuthChromaWhereIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -279,7 +279,7 @@ public class ChromaVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -186,7 +186,7 @@ public class ChromaVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -158,7 +158,7 @@ public class TokenSecuredChromaWhereIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -505,7 +505,7 @@ class ElasticsearchVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
@Bean

View File

@@ -220,7 +220,7 @@ public class ElasticsearchVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
@Bean

View File

@@ -128,7 +128,7 @@ public class HanaCloudVectorStoreIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -206,7 +206,7 @@ public class HanaVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -252,7 +252,7 @@ public class MariaDBStoreCustomNamesIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -451,7 +451,7 @@ public class MariaDBStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -203,7 +203,7 @@ public class MariaDBStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -252,7 +252,7 @@ class MilvusVectorStoreCustomFieldNamesIT {
@Bean
EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -361,7 +361,7 @@ public class MilvusVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
// return new OpenAiEmbeddingModel(new
// OpenAiApi(System.getenv("OPENAI_API_KEY")), MetadataMode.EMBED,
// OpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build());

View File

@@ -190,7 +190,7 @@ public class MilvusVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -345,7 +345,7 @@ class MongoDBAtlasVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
@Bean

View File

@@ -209,7 +209,7 @@ public class MongoDbVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
@Bean

View File

@@ -374,7 +374,7 @@ class Neo4jVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -194,7 +194,7 @@ public class Neo4jVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -602,7 +602,7 @@ class OpenSearchVectorStoreIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -226,7 +226,7 @@ public class OpenSearchVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -245,7 +245,7 @@ public class PgVectorStoreCustomNamesIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -483,7 +483,7 @@ public class PgVectorStoreIT extends BaseVectorStoreTests {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}

View File

@@ -214,7 +214,7 @@ public class PgVectorStoreObservationIT {
@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
}
}