Remove deprecations from 1.0.0-M6
- Remove deprecations from models, vector stores and usage - Deprecations from FunctionCallback and ObservationContext/Convention will be in a separate PR Models updates - Remove AbstractToolCallSupport from the models which use ToolCallingManager - Remove deprecated constructors and their usage - Remove FunctionCallbackResolver and FunctionCallbacks usage in the models - Add back deprecations for VectorStoreChatMemoryAdvisor until builder is fixed - Update OpenAiPaymentTransactionIT to use ToolCallbackResolver in config Signed-off-by: Ilayaperumal Gopinathan <ilayaperumal.gopinathan@broadcom.com>
This commit is contained in:
committed by
Mark Pollack
parent
a1e417f350
commit
ded9facfe5
@@ -30,13 +30,6 @@ import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.model.tool.LegacyToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
import org.springframework.lang.Nullable;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
@@ -59,7 +52,6 @@ import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.metadata.UsageUtils;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -72,10 +64,12 @@ import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
@@ -94,7 +88,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Alexandros Pappas
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class AnthropicChatModel implements ChatModel {
|
||||
|
||||
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getValue();
|
||||
|
||||
@@ -135,111 +129,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
|
||||
*/
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi) {
|
||||
this(anthropicApi,
|
||||
AnthropicChatOptions.builder()
|
||||
.model(DEFAULT_MODEL_NAME)
|
||||
.maxTokens(DEFAULT_MAX_TOKENS)
|
||||
.temperature(DEFAULT_TEMPERATURE)
|
||||
.build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
|
||||
this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate) {
|
||||
this(anthropicApi, defaultOptions, retryTemplate, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
* @param functionCallbackResolver the function callback resolver used to resolve the
|
||||
* function by its name.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) {
|
||||
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of());
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
* @param functionCallbackResolver the function callback resolver used to resolve the
|
||||
* function by its name.
|
||||
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
|
||||
* calls.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver,
|
||||
List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, toolFunctionCallbacks,
|
||||
ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
* @param functionCallbackResolver the function callback resolver used to resolve the
|
||||
* function by its name.
|
||||
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
|
||||
* calls.
|
||||
* @deprecated Use {@link AnthropicChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate, @Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {
|
||||
this(anthropicApi, defaultOptions,
|
||||
LegacyToolCallingManager.builder()
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.functionCallbacks(toolFunctionCallbacks)
|
||||
.build(),
|
||||
retryTemplate, observationRegistry);
|
||||
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
|
||||
+ "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
|
||||
}
|
||||
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, AnthropicChatOptions.builder().build(), List.of());
|
||||
|
||||
Assert.notNull(anthropicApi, "anthropicApi cannot be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
|
||||
@@ -488,10 +380,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
|
||||
AnthropicChatOptions.class);
|
||||
}
|
||||
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
|
||||
AnthropicChatOptions.class);
|
||||
}
|
||||
else {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
|
||||
AnthropicChatOptions.class);
|
||||
@@ -648,10 +536,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
|
||||
|
||||
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolCallbacks;
|
||||
|
||||
private ToolCallingManager toolCallingManager;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
@@ -679,18 +563,6 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
|
||||
this.toolCallbacks = toolCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder observationRegistry(ObservationRegistry observationRegistry) {
|
||||
this.observationRegistry = observationRegistry;
|
||||
return this;
|
||||
@@ -698,22 +570,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
|
||||
|
||||
public AnthropicChatModel build() {
|
||||
if (toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver cannot be set when toolCallingManager is set");
|
||||
Assert.isNull(toolCallbacks, "toolCallbacks cannot be set when toolCallingManager is set");
|
||||
|
||||
return new AnthropicChatModel(anthropicApi, defaultOptions, toolCallingManager, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
if (functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager cannot be set when functionCallbackResolver is set");
|
||||
List<FunctionCallback> toolCallbacks = this.toolCallbacks != null ? this.toolCallbacks : List.of();
|
||||
|
||||
return new AnthropicChatModel(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver,
|
||||
toolCallbacks, observationRegistry);
|
||||
}
|
||||
|
||||
return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
@@ -34,6 +34,8 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.model.tool.DefaultToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -171,8 +173,7 @@ public class AnthropicChatModelObservationIT {
|
||||
public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi,
|
||||
TestObservationRegistry observationRegistry) {
|
||||
return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(),
|
||||
RetryTemplate.defaultInstance(), new DefaultFunctionCallbackResolver(), List.of(),
|
||||
observationRegistry);
|
||||
ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -73,7 +73,6 @@ import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.metadata.UsageUtils;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -86,16 +85,12 @@ import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.LegacyToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -120,7 +115,7 @@ import org.springframework.util.CollectionUtils;
|
||||
* @see com.azure.ai.openai.OpenAIClient
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class AzureOpenAiChatModel implements ChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModel.class);
|
||||
|
||||
@@ -162,53 +157,8 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
*/
|
||||
private final ToolCallingManager toolCallingManager;
|
||||
|
||||
@Deprecated
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
|
||||
this(openAIClientBuilder,
|
||||
AzureOpenAiChatOptions.builder()
|
||||
.deploymentName(DEFAULT_DEPLOYMENT_NAME)
|
||||
.temperature(DEFAULT_TEMPERATURE)
|
||||
.build());
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options) {
|
||||
this(openAIClientBuilder, options, null);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
|
||||
FunctionCallbackResolver functionCallbackResolver) {
|
||||
this(openAIClientBuilder, options, functionCallbackResolver, List.of());
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this(openAIClientBuilder, options, functionCallbackResolver, toolFunctionCallbacks, ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {
|
||||
this(openAIClientBuilder, options,
|
||||
LegacyToolCallingManager.builder()
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.functionCallbacks(toolFunctionCallbacks)
|
||||
.build(),
|
||||
observationRegistry);
|
||||
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
|
||||
+ "Please use the AzureOpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
|
||||
}
|
||||
|
||||
public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions,
|
||||
ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) {
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, AzureOpenAiChatOptions.builder().build(), List.of());
|
||||
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
|
||||
@@ -534,10 +484,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
|
||||
if (prompt.getOptions() != null) {
|
||||
AzureOpenAiChatOptions updatedRuntimeOptions;
|
||||
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
|
||||
FunctionCallingOptions.class, AzureOpenAiChatOptions.class);
|
||||
}
|
||||
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
|
||||
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions,
|
||||
ToolCallingChatOptions.class, AzureOpenAiChatOptions.class);
|
||||
@@ -668,10 +614,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
|
||||
AzureOpenAiChatOptions.class);
|
||||
}
|
||||
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
|
||||
AzureOpenAiChatOptions.class);
|
||||
}
|
||||
else {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
|
||||
AzureOpenAiChatOptions.class);
|
||||
@@ -978,10 +920,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
|
||||
private ToolCallingManager toolCallingManager;
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolFunctionCallbacks;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
|
||||
private Builder() {
|
||||
@@ -1002,18 +940,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this.toolFunctionCallbacks = toolFunctionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder observationRegistry(ObservationRegistry observationRegistry) {
|
||||
this.observationRegistry = observationRegistry;
|
||||
return this;
|
||||
@@ -1021,29 +947,9 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha
|
||||
|
||||
public AzureOpenAiChatModel build() {
|
||||
if (toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver cannot be set when toolCallingManager is set");
|
||||
Assert.isNull(toolFunctionCallbacks,
|
||||
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
|
||||
|
||||
return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
if (functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager cannot be set when functionCallbackResolver is set");
|
||||
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
|
||||
: List.of();
|
||||
|
||||
return new Builder().openAIClientBuilder(openAIClientBuilder)
|
||||
.defaultOptions(defaultOptions)
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.toolFunctionCallbacks(toolCallbacks)
|
||||
.observationRegistry(observationRegistry)
|
||||
.build();
|
||||
}
|
||||
|
||||
return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
@@ -82,7 +82,6 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -95,10 +94,6 @@ import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.LegacyToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
@@ -135,7 +130,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Jihoon Kim
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class BedrockProxyChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class BedrockProxyChatModel implements ChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModel.class);
|
||||
|
||||
@@ -161,33 +156,10 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
*/
|
||||
private ChatModelObservationConvention observationConvention;
|
||||
|
||||
/**
|
||||
* @deprecated Use
|
||||
* {@link #BedrockProxyChatModel(BedrockRuntimeClient, BedrockRuntimeAsyncClient, ToolCallingChatOptions, ObservationRegistry, ToolCallingManager)}
|
||||
* instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
|
||||
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, FunctionCallingOptions defaultOptions,
|
||||
FunctionCallbackResolver functionCallbackResolver, List<FunctionCallback> toolFunctionCallbacks,
|
||||
ObservationRegistry observationRegistry) {
|
||||
|
||||
this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, from(defaultOptions), observationRegistry,
|
||||
LegacyToolCallingManager.builder()
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.functionCallbacks(toolFunctionCallbacks)
|
||||
.build());
|
||||
}
|
||||
|
||||
public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
|
||||
BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions,
|
||||
ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
|
||||
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, FunctionCallingOptions.builder().build(), List.of());
|
||||
|
||||
Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
|
||||
Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
|
||||
@@ -199,21 +171,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
this.toolCallingManager = toolCallingManager;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
private static ToolCallingChatOptions from(FunctionCallingOptions options) {
|
||||
return ToolCallingChatOptions.builder()
|
||||
.model(options.getModel())
|
||||
.maxTokens(options.getMaxTokens())
|
||||
.stopSequences(options.getStopSequences())
|
||||
.temperature(options.getTemperature())
|
||||
.topP(options.getTopP())
|
||||
.toolCallbacks(options.getFunctionCallbacks())
|
||||
.toolNames(options.getFunctions())
|
||||
.internalToolExecutionEnabled(options.getProxyToolCalls() != null ? !options.getProxyToolCalls() : false)
|
||||
.toolContext(options.getToolContext())
|
||||
.build();
|
||||
}
|
||||
|
||||
private static ToolCallingChatOptions from(ChatOptions options) {
|
||||
return ToolCallingChatOptions.builder()
|
||||
.model(options.getModel())
|
||||
@@ -295,9 +252,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
|
||||
runtimeOptions = toolCallingChatOptions.copy();
|
||||
}
|
||||
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
runtimeOptions = from(functionCallingOptions);
|
||||
}
|
||||
else {
|
||||
runtimeOptions = from(prompt.getOptions());
|
||||
}
|
||||
@@ -804,10 +758,6 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
|
||||
private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build();
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolFunctionCallbacks;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
|
||||
private ChatModelObservationConvention customObservationConvention;
|
||||
@@ -824,160 +774,47 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #credentialsProvider(AwsCredentialsProvider)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withCredentialsProvider(AwsCredentialsProvider credentialsProvider) {
|
||||
Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null.");
|
||||
this.credentialsProvider = credentialsProvider;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) {
|
||||
Assert.notNull(credentialsProvider, "'credentialsProvider' must not be null.");
|
||||
this.credentialsProvider = credentialsProvider;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #region(Region)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withRegion(Region region) {
|
||||
Assert.notNull(region, "'region' must not be null.");
|
||||
this.region = region;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder region(Region region) {
|
||||
Assert.notNull(region, "'region' must not be null.");
|
||||
this.region = region;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #timeout(Duration)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withTimeout(Duration timeout) {
|
||||
Assert.notNull(timeout, "'timeout' must not be null.");
|
||||
this.timeout = timeout;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder timeout(Duration timeout) {
|
||||
Assert.notNull(timeout, "'timeout' must not be null.");
|
||||
this.timeout = timeout;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #defaultOptions(ToolCallingChatOptions)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withDefaultOptions(FunctionCallingOptions defaultOptions) {
|
||||
Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
|
||||
return this.defaultOptions(ToolCallingChatOptions.builder()
|
||||
.model(defaultOptions.getModel())
|
||||
.maxTokens(defaultOptions.getMaxTokens())
|
||||
.stopSequences(defaultOptions.getStopSequences())
|
||||
.temperature(defaultOptions.getTemperature())
|
||||
.topP(defaultOptions.getTopP())
|
||||
.toolCallbacks(defaultOptions.getFunctionCallbacks())
|
||||
.toolNames(defaultOptions.getFunctions())
|
||||
.internalToolExecutionEnabled(
|
||||
defaultOptions.getProxyToolCalls() != null ? !defaultOptions.getProxyToolCalls() : false)
|
||||
.toolContext(defaultOptions.getToolContext())
|
||||
.build());
|
||||
}
|
||||
|
||||
public Builder defaultOptions(ToolCallingChatOptions defaultOptions) {
|
||||
Assert.notNull(defaultOptions, "'defaultOptions' must not be null.");
|
||||
this.defaultOptions = defaultOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated To be removed after M6
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withFunctionCallbackContext(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated To be removed after M6
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withToolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this.toolFunctionCallbacks = toolFunctionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #observationRegistry(ObservationRegistry)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withObservationRegistry(ObservationRegistry observationRegistry) {
|
||||
Assert.notNull(observationRegistry, "'observationRegistry' must not be null.");
|
||||
this.observationRegistry = observationRegistry;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder observationRegistry(ObservationRegistry observationRegistry) {
|
||||
Assert.notNull(observationRegistry, "'observationRegistry' must not be null.");
|
||||
this.observationRegistry = observationRegistry;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use
|
||||
* {@link #customObservationConvention(ChatModelObservationConvention)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withCustomObservationConvention(ChatModelObservationConvention observationConvention) {
|
||||
Assert.notNull(observationConvention, "'observationConvention' must not be null.");
|
||||
this.customObservationConvention = observationConvention;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder customObservationConvention(ChatModelObservationConvention observationConvention) {
|
||||
Assert.notNull(observationConvention, "'observationConvention' must not be null.");
|
||||
this.customObservationConvention = observationConvention;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #bedrockRuntimeClient(BedrockRuntimeClient)} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withBedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder bedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
|
||||
this.bedrockRuntimeClient = bedrockRuntimeClient;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient)}
|
||||
* instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withBedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
|
||||
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
|
||||
this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
|
||||
return this;
|
||||
@@ -1013,23 +850,11 @@ public class BedrockProxyChatModel extends AbstractToolCallSupport implements Ch
|
||||
BedrockProxyChatModel bedrockProxyChatModel = null;
|
||||
|
||||
if (this.toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver cannot be set when toolCallingManager is set");
|
||||
Assert.isNull(toolFunctionCallbacks,
|
||||
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
|
||||
|
||||
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
|
||||
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry,
|
||||
this.toolCallingManager);
|
||||
|
||||
}
|
||||
else if (this.functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager cannot be set when functionCallbackResolver is set");
|
||||
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
|
||||
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.functionCallbackResolver,
|
||||
this.toolFunctionCallbacks, this.observationRegistry);
|
||||
}
|
||||
else {
|
||||
bedrockProxyChatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient,
|
||||
this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry,
|
||||
|
||||
@@ -36,7 +36,6 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -367,7 +366,7 @@ class BedrockConverseChatClientIT {
|
||||
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(FunctionCallingOptions.builder().model(modelName).build())
|
||||
.options(ToolCallingChatOptions.builder().model(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?")
|
||||
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
|
||||
.call()
|
||||
@@ -389,7 +388,7 @@ class BedrockConverseChatClientIT {
|
||||
// @formatter:off
|
||||
String response = ChatClient.create(this.chatModel).prompt()
|
||||
// TODO consider adding model(...) method to ChatClient as a shortcut to
|
||||
.options(FunctionCallingOptions.builder().model(modelName).build())
|
||||
.options(ToolCallingChatOptions.builder().model(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
|
||||
.call()
|
||||
.content();
|
||||
|
||||
@@ -21,7 +21,7 @@ import java.time.Duration;
|
||||
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
|
||||
@@ -42,7 +42,7 @@ public class BedrockConverseTestConfiguration {
|
||||
.region(Region.US_EAST_1)
|
||||
// .region(Region.US_EAST_1)
|
||||
.timeout(Duration.ofSeconds(120))
|
||||
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
|
||||
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -16,11 +16,7 @@
|
||||
|
||||
package org.springframework.ai.bedrock.converse;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
@@ -40,8 +36,9 @@ import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
|
||||
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.ArgumentMatchers.isA;
|
||||
@@ -61,8 +58,10 @@ public class BedrockConverseUsageAggregationTests {
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
this.chatModel = new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient,
|
||||
FunctionCallingOptions.builder().build(), null, List.of(), ObservationRegistry.NOOP);
|
||||
this.chatModel = BedrockProxyChatModel.builder()
|
||||
.bedrockRuntimeClient(this.bedrockRuntimeClient)
|
||||
.bedrockRuntimeAsyncClient(this.bedrockRuntimeAsyncClient)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -138,14 +137,13 @@ public class BedrockConverseUsageAggregationTests {
|
||||
given(this.bedrockRuntimeClient.converse(isA(ConverseRequest.class))).willReturn(converseResponseToolUse)
|
||||
.willReturn(converseResponseFinal);
|
||||
|
||||
FunctionCallback functionCallback = FunctionCallback.builder()
|
||||
.function("getCurrentWeather", (Request request) -> "15.0°C")
|
||||
ToolCallback toolCallback = FunctionToolCallback.builder("getCurrentWeather", (Request request) -> "15.0°C")
|
||||
.description("Gets the weather in location")
|
||||
.inputType(Request.class)
|
||||
.build();
|
||||
|
||||
var result = this.chatModel.call(new Prompt("What is the weather in Paris?",
|
||||
FunctionCallingOptions.builder().functionCallbacks(functionCallback).build()));
|
||||
ToolCallingChatOptions.builder().toolCallbacks(toolCallback).build()));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.getResult().getOutput().getText())
|
||||
|
||||
@@ -92,7 +92,7 @@ class BedrockProxyChatModelIT {
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage),
|
||||
FunctionCallingOptions.builder().model(modelName).build());
|
||||
ToolCallingChatOptions.builder().model(modelName).build());
|
||||
ChatResponse response = this.chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0);
|
||||
@@ -128,7 +128,7 @@ class BedrockProxyChatModelIT {
|
||||
|
||||
@Test
|
||||
void streamingWithTokenUsage() {
|
||||
var promptOptions = FunctionCallingOptions.builder().temperature(0.0).build();
|
||||
var promptOptions = ToolCallingChatOptions.builder().temperature(0.0).build();
|
||||
|
||||
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
|
||||
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
|
||||
@@ -255,9 +255,8 @@ class BedrockProxyChatModelIT {
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = FunctionCallingOptions.builder()
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
var promptOptions = ToolCallingChatOptions.builder()
|
||||
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -305,10 +304,9 @@ class BedrockProxyChatModelIT {
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = FunctionCallingOptions.builder()
|
||||
var promptOptions = ToolCallingChatOptions.builder()
|
||||
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description(
|
||||
"Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
@@ -333,7 +331,7 @@ class BedrockProxyChatModelIT {
|
||||
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
// @formatter:off
|
||||
ChatResponse response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(FunctionCallingOptions.builder().model(model).build())
|
||||
.options(ToolCallingChatOptions.builder().model(model).build())
|
||||
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
|
||||
.call()
|
||||
.chatResponse();
|
||||
@@ -348,7 +346,7 @@ class BedrockProxyChatModelIT {
|
||||
String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
// @formatter:off
|
||||
ChatResponse response = ChatClient.create(this.chatModel).prompt()
|
||||
.options(FunctionCallingOptions.builder().model(model).build())
|
||||
.options(ToolCallingChatOptions.builder().model(model).build())
|
||||
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
|
||||
.stream()
|
||||
.chatResponse()
|
||||
|
||||
@@ -34,7 +34,7 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -67,7 +67,7 @@ public class BedrockProxyChatModelObservationIT {
|
||||
|
||||
@Test
|
||||
void observationForChatOperation() {
|
||||
var options = FunctionCallingOptions.builder()
|
||||
var options = ToolCallingChatOptions.builder()
|
||||
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
.maxTokens(2048)
|
||||
.stopSequences(List.of("this-is-the-end"))
|
||||
@@ -89,7 +89,7 @@ public class BedrockProxyChatModelObservationIT {
|
||||
|
||||
@Test
|
||||
void observationForStreamingChatOperation() {
|
||||
var options = FunctionCallingOptions.builder()
|
||||
var options = ToolCallingChatOptions.builder()
|
||||
.model("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
.maxTokens(2048)
|
||||
.stopSequences(List.of("this-is-the-end"))
|
||||
@@ -170,10 +170,10 @@ public class BedrockProxyChatModelObservationIT {
|
||||
String modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
|
||||
return BedrockProxyChatModel.builder()
|
||||
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.withRegion(Region.US_EAST_1)
|
||||
.withObservationRegistry(observationRegistry)
|
||||
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
|
||||
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.region(Region.US_EAST_1)
|
||||
.observationRegistry(observationRegistry)
|
||||
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.model.Media;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
@@ -185,7 +186,7 @@ public class BedrockNovaChatClientIT {
|
||||
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.region(Region.US_EAST_1)
|
||||
.timeout(Duration.ofSeconds(120))
|
||||
.withDefaultOptions(FunctionCallingOptions.builder().model(modelId).build())
|
||||
.defaultOptions(ToolCallingChatOptions.builder().model(modelId).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@ public final class BedrockConverseChatModelMain {
|
||||
var prompt = new Prompt("Tell me a joke?", ChatOptions.builder().model(modelId).build());
|
||||
|
||||
var chatModel = BedrockProxyChatModel.builder()
|
||||
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.withRegion(Region.US_EAST_1)
|
||||
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
var chatResponse = chatModel.call(prompt);
|
||||
|
||||
@@ -24,8 +24,8 @@ import software.amazon.awssdk.regions.Region;
|
||||
import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
|
||||
import org.springframework.ai.bedrock.converse.MockWeatherService;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
|
||||
/**
|
||||
* Used for reverse engineering the protocol
|
||||
@@ -48,18 +48,17 @@ public final class BedrockConverseChatModelMain3 {
|
||||
// "What's the weather like in San Francisco, Tokyo, and Paris? Return the
|
||||
// temperature in Celsius.",
|
||||
"What's the weather like in Paris? Return the temperature in Celsius.",
|
||||
FunctionCallingOptions.builder()
|
||||
ToolCallingChatOptions.builder()
|
||||
.model(modelId)
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function("getCurrentWeather", new MockWeatherService())
|
||||
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build());
|
||||
|
||||
BedrockProxyChatModel chatModel = BedrockProxyChatModel.builder()
|
||||
.withCredentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.withRegion(Region.US_EAST_1)
|
||||
.credentialsProvider(EnvironmentVariableCredentialsProvider.create())
|
||||
.region(Region.US_EAST_1)
|
||||
.build();
|
||||
|
||||
var response = chatModel.call(prompt);
|
||||
|
||||
@@ -89,7 +89,7 @@ import org.springframework.util.MimeType;
|
||||
* @author Alexandros Pappas
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class MistralAiChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class MistralAiChatModel implements ChatModel {
|
||||
|
||||
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
|
||||
|
||||
@@ -121,73 +121,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
|
||||
*/
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link MistralAiChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi) {
|
||||
this(mistralAiApi,
|
||||
MistralAiChatOptions.builder()
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.safePrompt(false)
|
||||
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.build());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link MistralAiChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
|
||||
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link MistralAiChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver, @Nullable RetryTemplate retryTemplate) {
|
||||
this(mistralAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link MistralAiChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
|
||||
this(mistralAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
|
||||
ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link MistralAiChatModel.Builder}.
|
||||
*/
|
||||
@Deprecated
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
this(mistralAiApi, options,
|
||||
LegacyToolCallingManager.builder()
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.functionCallbacks(toolFunctionCallbacks)
|
||||
.build(),
|
||||
retryTemplate, observationRegistry);
|
||||
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
|
||||
+ "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
|
||||
}
|
||||
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions defaultOptions,
|
||||
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, MistralAiChatOptions.builder().build(), List.of());
|
||||
Assert.notNull(mistralAiApi, "mistralAiApi cannot be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
|
||||
@@ -594,15 +530,11 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.safePrompt(false)
|
||||
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.build();
|
||||
|
||||
private ToolCallingManager toolCallingManager;
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolFunctionCallbacks;
|
||||
|
||||
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
@@ -625,18 +557,6 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this.toolFunctionCallbacks = toolFunctionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder retryTemplate(RetryTemplate retryTemplate) {
|
||||
this.retryTemplate = retryTemplate;
|
||||
return this;
|
||||
@@ -649,25 +569,9 @@ public class MistralAiChatModel extends AbstractToolCallSupport implements ChatM
|
||||
|
||||
public MistralAiChatModel build() {
|
||||
if (toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver cannot be set when toolCallingManager is set");
|
||||
Assert.isNull(toolFunctionCallbacks,
|
||||
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
|
||||
|
||||
return new MistralAiChatModel(mistralAiApi, defaultOptions, toolCallingManager, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
if (functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager cannot be set when functionCallbackResolver is set");
|
||||
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
|
||||
: List.of();
|
||||
|
||||
return new MistralAiChatModel(mistralAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,
|
||||
retryTemplate, observationRegistry);
|
||||
}
|
||||
|
||||
return new MistralAiChatModel(mistralAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
@@ -265,12 +265,6 @@ public class MistralAiApi {
|
||||
public enum ChatModel implements ChatModelDescription {
|
||||
|
||||
// @formatter:off
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
OPEN_MISTRAL_7B("open-mistral-7b"),
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
OPEN_MIXTRAL_7B("open-mixtral-8x7b"),
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
OPEN_MIXTRAL_22B("open-mixtral-8x22b"),
|
||||
// Premier Models
|
||||
CODESTRAL("codestral-latest"),
|
||||
LARGE("mistral-large-latest"),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -32,7 +32,6 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -70,7 +69,7 @@ public class MistralAiChatModelObservationIT {
|
||||
@Test
|
||||
void observationForChatOperation() {
|
||||
var options = MistralAiChatOptions.builder()
|
||||
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.maxTokens(2048)
|
||||
.stop(List.of("this-is-the-end"))
|
||||
.temperature(0.7)
|
||||
@@ -91,7 +90,7 @@ public class MistralAiChatModelObservationIT {
|
||||
@Test
|
||||
void observationForStreamingChatOperation() {
|
||||
var options = MistralAiChatOptions.builder()
|
||||
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.maxTokens(2048)
|
||||
.stop(List.of("this-is-the-end"))
|
||||
.temperature(0.7)
|
||||
@@ -125,12 +124,12 @@ public class MistralAiChatModelObservationIT {
|
||||
.doesNotHaveAnyRemainingCurrentObservation()
|
||||
.hasObservationWithNameEqualTo(DefaultChatModelObservationConvention.DEFAULT_NAME)
|
||||
.that()
|
||||
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.hasContextualNameEqualTo("chat " + MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
|
||||
AiOperationType.CHAT.value())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.MISTRAL_AI.value())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(),
|
||||
MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.RESPONSE_MODEL.asString(),
|
||||
StringUtils.hasText(responseMetadata.getModel()) ? responseMetadata.getModel()
|
||||
: KeyValue.NONE_VALUE)
|
||||
@@ -181,9 +180,12 @@ public class MistralAiChatModelObservationIT {
|
||||
@Bean
|
||||
public MistralAiChatModel openAiChatModel(MistralAiApi mistralAiApi,
|
||||
TestObservationRegistry observationRegistry) {
|
||||
return new MistralAiChatModel(mistralAiApi, MistralAiChatOptions.builder().build(),
|
||||
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
|
||||
observationRegistry);
|
||||
return MistralAiChatModel.builder()
|
||||
.mistralAiApi(mistralAiApi)
|
||||
.defaultOptions(MistralAiChatOptions.builder().build())
|
||||
.retryTemplate(RetryTemplate.defaultInstance())
|
||||
.observationRegistry(observationRegistry)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -77,14 +77,16 @@ public class MistralAiRetryTests {
|
||||
this.retryListener = new TestRetryListener();
|
||||
this.retryTemplate.registerListener(this.retryListener);
|
||||
|
||||
this.chatModel = new MistralAiChatModel(this.mistralAiApi,
|
||||
MistralAiChatOptions.builder()
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.safePrompt(false)
|
||||
.model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue())
|
||||
.build(),
|
||||
null, this.retryTemplate);
|
||||
this.chatModel = MistralAiChatModel.builder()
|
||||
.mistralAiApi(this.mistralAiApi)
|
||||
.defaultOptions(MistralAiChatOptions.builder()
|
||||
.temperature(0.7)
|
||||
.topP(1.0)
|
||||
.safePrompt(false)
|
||||
.model(MistralAiApi.ChatModel.SMALL.getValue())
|
||||
.build())
|
||||
.retryTemplate(this.retryTemplate)
|
||||
.build();
|
||||
this.embeddingModel = new MistralAiEmbeddingModel(this.mistralAiApi, MetadataMode.EMBED,
|
||||
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
|
||||
this.retryTemplate);
|
||||
|
||||
@@ -43,8 +43,10 @@ public class MistralAiTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
|
||||
return new MistralAiChatModel(mistralAiApi,
|
||||
MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.OPEN_MIXTRAL_7B.getValue()).build());
|
||||
return MistralAiChatModel.builder()
|
||||
.mistralAiApi(mistralAiApi)
|
||||
.defaultOptions(MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.SMALL.getValue()).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ public class MistralAiApiIT {
|
||||
void chatCompletionEntity() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
|
||||
List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false));
|
||||
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.getBody()).isNotNull();
|
||||
@@ -64,7 +64,7 @@ public class MistralAiApiIT {
|
||||
""", Role.SYSTEM);
|
||||
|
||||
ResponseEntity<ChatCompletion> response = this.mistralAiApi.chatCompletionEntity(new ChatCompletionRequest(
|
||||
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, false));
|
||||
List.of(systemMessage, userMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, false));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.getBody()).isNotNull();
|
||||
@@ -74,7 +74,7 @@ public class MistralAiApiIT {
|
||||
void chatCompletionStream() {
|
||||
ChatCompletionMessage chatCompletionMessage = new ChatCompletionMessage("Hello world", Role.USER);
|
||||
Flux<ChatCompletionChunk> response = this.mistralAiApi.chatCompletionStream(new ChatCompletionRequest(
|
||||
List.of(chatCompletionMessage), MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue(), 0.8, true));
|
||||
List.of(chatCompletionMessage), MistralAiApi.ChatModel.SMALL.getValue(), 0.8, true));
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.collectList().block()).isNotNull();
|
||||
|
||||
@@ -28,13 +28,6 @@ import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.model.tool.LegacyToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
import org.springframework.lang.Nullable;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
@@ -45,7 +38,6 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.DefaultUsage;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -57,9 +49,9 @@ import org.springframework.ai.chat.observation.DefaultChatModelObservationConven
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
|
||||
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
|
||||
@@ -70,6 +62,8 @@ import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.ollama.management.ModelManagementOptions;
|
||||
import org.springframework.ai.ollama.management.OllamaModelManager;
|
||||
import org.springframework.ai.ollama.management.PullModelStrategy;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.util.json.JsonParser;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
@@ -89,7 +83,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class OllamaChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class OllamaChatModel implements ChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OllamaChatModel.class);
|
||||
|
||||
@@ -125,24 +119,8 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
@Deprecated
|
||||
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry,
|
||||
ModelManagementOptions modelManagementOptions) {
|
||||
this(ollamaApi, defaultOptions, new LegacyToolCallingManager(functionCallbackResolver, toolFunctionCallbacks),
|
||||
observationRegistry, modelManagementOptions);
|
||||
|
||||
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
|
||||
+ "Please use the OllamaChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
|
||||
}
|
||||
|
||||
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
|
||||
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, OllamaOptions.builder().build(), List.of());
|
||||
Assert.notNull(ollamaApi, "ollamaApi must not be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions must not be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
|
||||
@@ -381,10 +359,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
|
||||
OllamaOptions.class);
|
||||
}
|
||||
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
|
||||
OllamaOptions.class);
|
||||
}
|
||||
else {
|
||||
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
|
||||
OllamaOptions.class);
|
||||
@@ -540,10 +514,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
|
||||
private ToolCallingManager toolCallingManager;
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolFunctionCallbacks;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
|
||||
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
|
||||
@@ -566,18 +536,6 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this.toolFunctionCallbacks = toolFunctionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder observationRegistry(ObservationRegistry observationRegistry) {
|
||||
this.observationRegistry = observationRegistry;
|
||||
return this;
|
||||
@@ -590,24 +548,9 @@ public class OllamaChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
|
||||
public OllamaChatModel build() {
|
||||
if (toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver must not be set when toolCallingManager is set");
|
||||
Assert.isNull(toolFunctionCallbacks,
|
||||
"toolFunctionCallbacks must not be set when toolCallingManager is set");
|
||||
|
||||
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
|
||||
this.observationRegistry, this.modelManagementOptions);
|
||||
}
|
||||
|
||||
if (functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager must not be set when functionCallbackResolver is set");
|
||||
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
|
||||
: List.of();
|
||||
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.functionCallbackResolver,
|
||||
toolCallbacks, this.observationRegistry, this.modelManagementOptions);
|
||||
}
|
||||
|
||||
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
|
||||
this.observationRegistry, this.modelManagementOptions);
|
||||
}
|
||||
|
||||
@@ -55,9 +55,11 @@ class OllamaChatModelTests {
|
||||
|
||||
@Test
|
||||
void buildOllamaChatModelWithDeprecatedConstructor() {
|
||||
ChatModel chatModel = new OllamaChatModel(this.ollamaApi,
|
||||
OllamaOptions.builder().model(OllamaModel.MISTRAL).build(), null, null, ObservationRegistry.NOOP,
|
||||
ModelManagementOptions.builder().build());
|
||||
ChatModel chatModel = OllamaChatModel.builder()
|
||||
.ollamaApi(this.ollamaApi)
|
||||
.defaultOptions(OllamaOptions.builder().model(OllamaModel.MISTRAL).build())
|
||||
.observationRegistry(ObservationRegistry.NOOP)
|
||||
.build();
|
||||
assertThat(chatModel).isNotNull();
|
||||
}
|
||||
|
||||
|
||||
@@ -29,12 +29,6 @@ import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.model.tool.LegacyToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.lang.Nullable;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
@@ -50,7 +44,6 @@ import org.springframework.ai.chat.metadata.EmptyUsage;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.metadata.UsageUtils;
|
||||
import org.springframework.ai.chat.model.AbstractToolCallSupport;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -67,6 +60,9 @@ import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolExecutionResult;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
|
||||
@@ -79,6 +75,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
|
||||
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.core.io.ByteArrayResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -111,7 +108,7 @@ import org.springframework.util.StringUtils;
|
||||
* @see StreamingChatModel
|
||||
* @see OpenAiApi
|
||||
*/
|
||||
public class OpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
|
||||
public class OpenAiChatModel implements ChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
|
||||
|
||||
@@ -146,96 +143,8 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
*/
|
||||
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
/**
|
||||
* Creates an instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @throws IllegalArgumentException if openAiApi is null
|
||||
* @deprecated Use OpenAiChatModel.Builder.
|
||||
*/
|
||||
@Deprecated
|
||||
public OpenAiChatModel(OpenAiApi openAiApi) {
|
||||
this(openAiApi, OpenAiChatOptions.builder().model(OpenAiApi.DEFAULT_CHAT_MODEL).temperature(0.7).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes an instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
* @deprecated Use OpenAiChatModel.Builder.
|
||||
*/
|
||||
@Deprecated
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
|
||||
this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
* @param functionCallbackResolver The function callback resolver.
|
||||
* @param retryTemplate The retry template.
|
||||
* @deprecated Use OpenAiChatModel.Builder.
|
||||
*/
|
||||
@Deprecated
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver, RetryTemplate retryTemplate) {
|
||||
this(openAiApi, options, functionCallbackResolver, List.of(), retryTemplate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
* @param functionCallbackResolver The function callback resolver.
|
||||
* @param toolFunctionCallbacks The tool function callbacks.
|
||||
* @param retryTemplate The retry template.
|
||||
* @deprecated Use OpenAiChatModel.Builder.
|
||||
*/
|
||||
@Deprecated
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate) {
|
||||
this(openAiApi, options, functionCallbackResolver, toolFunctionCallbacks, retryTemplate,
|
||||
ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
* @param functionCallbackResolver The function callback resolver.
|
||||
* @param toolFunctionCallbacks The tool function callbacks.
|
||||
* @param retryTemplate The retry template.
|
||||
* @param observationRegistry The ObservationRegistry used for instrumentation.
|
||||
* @deprecated Use OpenAiChatModel.Builder or OpenAiChatModel(OpenAiApi,
|
||||
* OpenAiChatOptions, ToolCallingManager, RetryTemplate, ObservationRegistry).
|
||||
*/
|
||||
@Deprecated
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
|
||||
@Nullable FunctionCallbackResolver functionCallbackResolver,
|
||||
@Nullable List<FunctionCallback> toolFunctionCallbacks, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
this(openAiApi, options,
|
||||
LegacyToolCallingManager.builder()
|
||||
.functionCallbackResolver(functionCallbackResolver)
|
||||
.functionCallbacks(toolFunctionCallbacks)
|
||||
.build(),
|
||||
retryTemplate, observationRegistry);
|
||||
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
|
||||
+ "Please use the OpenAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
|
||||
}
|
||||
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager,
|
||||
RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
|
||||
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
|
||||
// because it modifies them. We are using ToolCallingManager instead,
|
||||
// so we just pass empty options here.
|
||||
super(null, OpenAiChatOptions.builder().build(), List.of());
|
||||
Assert.notNull(openAiApi, "openAiApi cannot be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
|
||||
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
|
||||
@@ -777,10 +686,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
|
||||
private ToolCallingManager toolCallingManager;
|
||||
|
||||
private FunctionCallbackResolver functionCallbackResolver;
|
||||
|
||||
private List<FunctionCallback> toolFunctionCallbacks;
|
||||
|
||||
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
|
||||
|
||||
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
@@ -803,18 +708,6 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
|
||||
this.functionCallbackResolver = functionCallbackResolver;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
|
||||
this.toolFunctionCallbacks = toolFunctionCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder retryTemplate(RetryTemplate retryTemplate) {
|
||||
this.retryTemplate = retryTemplate;
|
||||
return this;
|
||||
@@ -827,25 +720,9 @@ public class OpenAiChatModel extends AbstractToolCallSupport implements ChatMode
|
||||
|
||||
public OpenAiChatModel build() {
|
||||
if (toolCallingManager != null) {
|
||||
Assert.isNull(functionCallbackResolver,
|
||||
"functionCallbackResolver cannot be set when toolCallingManager is set");
|
||||
Assert.isNull(toolFunctionCallbacks,
|
||||
"toolFunctionCallbacks cannot be set when toolCallingManager is set");
|
||||
|
||||
return new OpenAiChatModel(openAiApi, defaultOptions, toolCallingManager, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
if (functionCallbackResolver != null) {
|
||||
Assert.isNull(toolCallingManager,
|
||||
"toolCallingManager cannot be set when functionCallbackResolver is set");
|
||||
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
|
||||
: List.of();
|
||||
|
||||
return new OpenAiChatModel(openAiApi, defaultOptions, functionCallbackResolver, toolCallbacks,
|
||||
retryTemplate, observationRegistry);
|
||||
}
|
||||
|
||||
return new OpenAiChatModel(openAiApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
|
||||
observationRegistry);
|
||||
}
|
||||
|
||||
@@ -85,97 +85,6 @@ public class OpenAiApi {
|
||||
|
||||
private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
|
||||
|
||||
/**
|
||||
* Create a new chat completion api with base URL set to https://api.openai.com
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String apiKey) {
|
||||
this(OpenAiApiConstants.DEFAULT_BASE_URL, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String baseUrl, String apiKey) {
|
||||
this(baseUrl, apiKey, RestClient.builder(), WebClient.builder());
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
WebClient.Builder webClientBuilder) {
|
||||
this(baseUrl, apiKey, restClientBuilder, webClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, apiKey, "/v1/chat/completions", "/v1/embeddings", restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param completionsPath the path to the chat completions endpoint.
|
||||
* @param embeddingsPath the path to the embeddings endpoint.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String baseUrl, String apiKey, String completionsPath, String embeddingsPath,
|
||||
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), completionsPath, embeddingsPath,
|
||||
restClientBuilder, webClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param headers the http headers to use.
|
||||
* @param completionsPath the path to the chat completions endpoint.
|
||||
* @param embeddingsPath the path to the embeddings endpoint.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated since 1.0.0.M6 - use {@link #builder()} instead
|
||||
*/
|
||||
@Deprecated(since = "1.0.0.M6")
|
||||
public OpenAiApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers, String completionsPath,
|
||||
String embeddingsPath, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, new SimpleApiKey(apiKey), headers, completionsPath, embeddingsPath, restClientBuilder,
|
||||
webClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new chat completion api.
|
||||
* @param baseUrl api base URL.
|
||||
|
||||
@@ -57,79 +57,6 @@ public class OpenAiAudioApi {
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
/**
|
||||
* Create a new audio api.
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiAudioApi(String openAiToken) {
|
||||
this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), WebClient.builder(),
|
||||
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new audio api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
Consumer<HttpHeaders> authHeaders;
|
||||
if (openAiToken != null && !openAiToken.isEmpty()) {
|
||||
authHeaders = h -> h.setBearerAuth(openAiToken);
|
||||
}
|
||||
else {
|
||||
authHeaders = h -> {
|
||||
};
|
||||
}
|
||||
|
||||
this.restClient = restClientBuilder.baseUrl(baseUrl)
|
||||
.defaultHeaders(authHeaders)
|
||||
.defaultStatusHandler(responseErrorHandler)
|
||||
.build();
|
||||
|
||||
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new audio api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiAudioApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, webClientBuilder,
|
||||
responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new audio api.
|
||||
* @param baseUrl api base URL.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param headers the http headers to use.
|
||||
* @param restClientBuilder RestClient builder.
|
||||
* @param webClientBuilder WebClient builder.
|
||||
* @param responseErrorHandler Response error handler.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiAudioApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, webClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new audio api.
|
||||
* @param baseUrl api base URL.
|
||||
@@ -496,71 +423,26 @@ public class OpenAiAudioApi {
|
||||
|
||||
private Float speed;
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #model(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder model(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #input(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withInput(String input) {
|
||||
this.input = input;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder input(String input) {
|
||||
this.input = input;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #voice(Voice)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withVoice(Voice voice) {
|
||||
this.voice = voice;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder voice(Voice voice) {
|
||||
this.voice = voice;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #responseFormat(AudioResponseFormat)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withResponseFormat(AudioResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder responseFormat(AudioResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #speed(Float)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withSpeed(Float speed) {
|
||||
this.speed = speed;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder speed(Float speed) {
|
||||
this.speed = speed;
|
||||
return this;
|
||||
@@ -653,99 +535,36 @@ public class OpenAiAudioApi {
|
||||
|
||||
private GranularityType granularityType;
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #file(byte[])} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withFile(byte[] file) {
|
||||
this.file = file;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder file(byte[] file) {
|
||||
this.file = file;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #model(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder model(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #language(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withLanguage(String language) {
|
||||
this.language = language;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder language(String language) {
|
||||
this.language = language;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #prompt(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withPrompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder prompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withResponseFormat(TranscriptResponseFormat response_format) {
|
||||
this.responseFormat = response_format;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder responseFormat(TranscriptResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #temperature(Float)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withTemperature(Float temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder temperature(Float temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #granularityType(GranularityType)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withGranularityType(GranularityType granularityType) {
|
||||
this.granularityType = granularityType;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder granularityType(GranularityType granularityType) {
|
||||
this.granularityType = granularityType;
|
||||
return this;
|
||||
@@ -805,71 +624,26 @@ public class OpenAiAudioApi {
|
||||
|
||||
private Float temperature;
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #file(byte[])} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withFile(byte[] file) {
|
||||
this.file = file;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder file(byte[] file) {
|
||||
this.file = file;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #model(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withModel(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder model(String model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #prompt(String)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withPrompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder prompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #responseFormat(TranscriptResponseFormat)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withResponseFormat(TranscriptResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder responseFormat(TranscriptResponseFormat responseFormat) {
|
||||
this.responseFormat = responseFormat;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #temperature(Float)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public Builder withTemperature(Float temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder temperature(Float temperature) {
|
||||
this.temperature = temperature;
|
||||
return this;
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package org.springframework.ai.openai.api;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
@@ -30,7 +29,6 @@ import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
@@ -47,56 +45,6 @@ public class OpenAiImageApi {
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Image api with base URL set to {@code https://api.openai.com}.
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiImageApi(String openAiToken) {
|
||||
this(OpenAiApiConstants.DEFAULT_BASE_URL, openAiToken, RestClient.builder());
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Image API with the provided base URL.
|
||||
* @param baseUrl the base URL for the OpenAI API.
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
|
||||
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Image API with the provided base URL.
|
||||
* @param baseUrl the base URL for the OpenAI API.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param restClientBuilder the rest client builder to use.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Image API with the provided base URL.
|
||||
* @param baseUrl the base URL for the OpenAI API.
|
||||
* @param apiKey OpenAI apiKey.
|
||||
* @param headers the http headers to use.
|
||||
* @param restClientBuilder the rest client builder to use.
|
||||
* @param responseErrorHandler the response error handler to use.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiImageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Image API with the provided base URL.
|
||||
* @param baseUrl the base URL for the OpenAI API.
|
||||
|
||||
@@ -52,33 +52,6 @@ public class OpenAiModerationApi {
|
||||
|
||||
private final ObjectMapper objectMapper;
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Moderation api with base URL set to https://api.openai.com
|
||||
* @param openAiToken OpenAI apiKey.
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiModerationApi(String openAiToken) {
|
||||
this(DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link Builder} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
public OpenAiModerationApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||
|
||||
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
|
||||
if (openAiToken != null && !openAiToken.isEmpty()) {
|
||||
h.setBearerAuth(openAiToken);
|
||||
}
|
||||
h.setContentType(MediaType.APPLICATION_JSON);
|
||||
}).defaultStatusHandler(responseErrorHandler).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new OpenAI Moderation API with the provided base URL.
|
||||
* @param baseUrl the base URL for the OpenAI API.
|
||||
|
||||
@@ -74,8 +74,10 @@ class ChatCompletionRequestTests {
|
||||
|
||||
@Test
|
||||
void createRequestWithChatOptions() {
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());
|
||||
var client = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build())
|
||||
.build();
|
||||
|
||||
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
|
||||
|
||||
@@ -101,8 +103,10 @@ class ChatCompletionRequestTests {
|
||||
void promptOptionsTools() {
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder().model("DEFAULT_MODEL").build());
|
||||
var client = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
|
||||
.defaultOptions(OpenAiChatOptions.builder().model("DEFAULT_MODEL").build())
|
||||
.build();
|
||||
|
||||
var prompt = client.buildRequestPrompt(new Prompt("Test message content",
|
||||
OpenAiChatOptions.builder()
|
||||
@@ -128,15 +132,17 @@ class ChatCompletionRequestTests {
|
||||
void defaultOptionsTools() {
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder()
|
||||
.model("DEFAULT_MODEL")
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build());
|
||||
var client = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder().apiKey("TEST").build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("DEFAULT_MODEL")
|
||||
.functionCallbacks(List.of(FunctionCallback.builder()
|
||||
.function(TOOL_FUNCTION_NAME, new MockWeatherService())
|
||||
.description("Get the weather in location")
|
||||
.inputType(MockWeatherService.Request.class)
|
||||
.build()))
|
||||
.build())
|
||||
.build();
|
||||
|
||||
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
|
||||
|
||||
|
||||
@@ -61,28 +61,25 @@ public class OpenAiTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiChatModel(OpenAiApi api) {
|
||||
OpenAiChatModel openAiChatModel = new OpenAiChatModel(api,
|
||||
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build());
|
||||
return openAiChatModel;
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(api)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioTranscriptionModel openAiTranscriptionModel(OpenAiAudioApi api) {
|
||||
OpenAiAudioTranscriptionModel openAiTranscriptionModel = new OpenAiAudioTranscriptionModel(api);
|
||||
return openAiTranscriptionModel;
|
||||
return new OpenAiAudioTranscriptionModel(api);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
|
||||
OpenAiAudioSpeechModel openAiAudioSpeechModel = new OpenAiAudioSpeechModel(api);
|
||||
return openAiAudioSpeechModel;
|
||||
return new OpenAiAudioSpeechModel(api);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
|
||||
OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi);
|
||||
// openAiImageModel.setModel("foobar");
|
||||
return openAiImageModel;
|
||||
return new OpenAiImageModel(imageApi);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@@ -92,8 +89,7 @@ public class OpenAiTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public OpenAiModerationModel openAiModerationClient(OpenAiModerationApi openAiModerationApi) {
|
||||
OpenAiModerationModel openAiModerationModel = new OpenAiModerationModel(openAiModerationApi);
|
||||
return openAiModerationModel;
|
||||
return new OpenAiModerationModel(openAiModerationApi);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class OpenAiApiIT {
|
||||
|
||||
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
OpenAiApi openAiApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
|
||||
@Test
|
||||
void chatCompletionEntity() {
|
||||
|
||||
@@ -52,7 +52,7 @@ public class OpenAiApiToolFunctionCallIT {
|
||||
|
||||
MockWeatherService weatherService = new MockWeatherService();
|
||||
|
||||
OpenAiApi completionApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
OpenAiApi completionApi = OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
|
||||
private static <T> T fromJson(String json, Class<T> targetClass) {
|
||||
try {
|
||||
|
||||
@@ -74,7 +74,7 @@ public class MessageTypeContentTests {
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
this.chatModel = new OpenAiChatModel(this.openAiApi);
|
||||
this.chatModel = OpenAiChatModel.builder().openAiApi(this.openAiApi).build();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -73,7 +73,7 @@ public class OpenAiChatModelAdditionalHttpHeadersIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -219,12 +219,12 @@ class OpenAiChatModelFunctionCallingIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ public class OpenAiChatModelNoOpApiKeysIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.model.tool.DefaultToolCallingManager;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
@@ -169,14 +171,13 @@ public class OpenAiChatModelObservationIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi openAiApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi, TestObservationRegistry observationRegistry) {
|
||||
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().build(),
|
||||
new DefaultFunctionCallbackResolver(), List.of(), RetryTemplate.defaultInstance(),
|
||||
observationRegistry);
|
||||
ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.JsonMappingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingHelper;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = OpenAiChatModelProxyToolCallsIT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
class OpenAiChatModelProxyToolCallsIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelProxyToolCallsIT.class);
|
||||
|
||||
private static final String DEFAULT_MODEL = "gpt-4o-mini";
|
||||
|
||||
FunctionCallback functionDefinition = new FunctionCallingHelper.FunctionDefinition("getWeatherInLocation",
|
||||
"Get the weather in location", """
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["C", "F"]
|
||||
}
|
||||
},
|
||||
"required": ["location", "unit"]
|
||||
}
|
||||
""");
|
||||
|
||||
@Autowired
|
||||
private OpenAiChatModel chatModel;
|
||||
|
||||
// Helper class that reuses some of the {@link AbstractToolCallSupport} functionality
|
||||
// to help to implement the function call handling logic on the client side.
|
||||
private FunctionCallingHelper functionCallingHelper = new FunctionCallingHelper();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static Map<String, String> getFunctionArguments(String functionArguments) {
|
||||
try {
|
||||
return new ObjectMapper().readValue(functionArguments, Map.class);
|
||||
}
|
||||
catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Function which will be called by the AI model.
|
||||
private String getWeatherInLocation(String location, String unit) {
|
||||
|
||||
double temperature = 0;
|
||||
|
||||
if (location.contains("Paris")) {
|
||||
temperature = 15;
|
||||
}
|
||||
else if (location.contains("Tokyo")) {
|
||||
temperature = 10;
|
||||
}
|
||||
else if (location.contains("San Francisco")) {
|
||||
temperature = 30;
|
||||
}
|
||||
|
||||
return String.format("The weather in %s is %s%s", location, temperature, unit);
|
||||
}
|
||||
|
||||
@Test
|
||||
void functionCall() throws JsonMappingException, JsonProcessingException {
|
||||
|
||||
List<Message> messages = List
|
||||
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
|
||||
|
||||
var prompt = new Prompt(messages, promptOptions);
|
||||
|
||||
boolean isToolCall = false;
|
||||
|
||||
ChatResponse chatResponse = null;
|
||||
|
||||
do {
|
||||
|
||||
chatResponse = this.chatModel.call(prompt);
|
||||
|
||||
// We will have to convert the chatResponse into OpenAI assistant message.
|
||||
|
||||
// Note that the tool call check could be platform specific because the finish
|
||||
// reasons.
|
||||
isToolCall = this.functionCallingHelper.isToolCall(chatResponse,
|
||||
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
|
||||
OpenAiApi.ChatCompletionFinishReason.STOP.name()));
|
||||
|
||||
if (isToolCall) {
|
||||
|
||||
Optional<Generation> toolCallGeneration = chatResponse.getResults()
|
||||
.stream()
|
||||
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
|
||||
.findFirst();
|
||||
|
||||
assertThat(toolCallGeneration).isNotEmpty();
|
||||
|
||||
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
|
||||
|
||||
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
|
||||
|
||||
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
|
||||
|
||||
var functionName = toolCall.name();
|
||||
|
||||
assertThat(functionName).isEqualTo("getWeatherInLocation");
|
||||
|
||||
String functionArguments = toolCall.arguments();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, String> argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class);
|
||||
|
||||
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
|
||||
argumentsMap.get("unit").toString());
|
||||
|
||||
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName,
|
||||
ModelOptionsUtils.toJsonString(functionResponse)));
|
||||
}
|
||||
|
||||
ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());
|
||||
|
||||
List<Message> toolCallConversation = this.functionCallingHelper
|
||||
.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);
|
||||
|
||||
assertThat(toolCallConversation).isNotEmpty();
|
||||
|
||||
prompt = new Prompt(toolCallConversation, prompt.getOptions());
|
||||
}
|
||||
}
|
||||
while (isToolCall);
|
||||
|
||||
logger.info("Response: {}", chatResponse);
|
||||
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void functionStream() throws JsonMappingException, JsonProcessingException {
|
||||
|
||||
List<Message> messages = List
|
||||
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
|
||||
|
||||
var prompt = new Prompt(messages, promptOptions);
|
||||
|
||||
String response = processToolCall(prompt, Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
|
||||
OpenAiApi.ChatCompletionFinishReason.STOP.name()), toolCall -> {
|
||||
|
||||
var functionName = toolCall.name();
|
||||
|
||||
assertThat(functionName).isEqualTo("getWeatherInLocation");
|
||||
|
||||
String functionArguments = toolCall.arguments();
|
||||
|
||||
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
|
||||
|
||||
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
|
||||
argumentsMap.get("unit").toString());
|
||||
|
||||
return functionResponse;
|
||||
})
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(cr -> cr.getResult().getOutput().getText())
|
||||
.collect(Collectors.joining());
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response).contains("30", "10", "15");
|
||||
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> processToolCall(Prompt prompt, Set<String> finishReasons,
|
||||
Function<AssistantMessage.ToolCall, String> customFunction) {
|
||||
|
||||
Flux<ChatResponse> chatResponses = this.chatModel.stream(prompt);
|
||||
|
||||
return chatResponses.flatMap(chatResponse -> {
|
||||
|
||||
boolean isToolCall = this.functionCallingHelper.isToolCall(chatResponse, finishReasons);
|
||||
|
||||
if (isToolCall) {
|
||||
|
||||
Optional<Generation> toolCallGeneration = chatResponse.getResults()
|
||||
.stream()
|
||||
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
|
||||
.findFirst();
|
||||
|
||||
assertThat(toolCallGeneration).isNotEmpty();
|
||||
|
||||
AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
|
||||
|
||||
List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
|
||||
|
||||
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
|
||||
|
||||
String functionResponse = customFunction.apply(toolCall);
|
||||
|
||||
toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(),
|
||||
ModelOptionsUtils.toJsonString(functionResponse)));
|
||||
}
|
||||
|
||||
ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());
|
||||
|
||||
List<Message> toolCallConversation = this.functionCallingHelper
|
||||
.buildToolCallConversation(prompt.getInstructions(), assistantMessage, toolMessageResponse);
|
||||
|
||||
assertThat(toolCallConversation).isNotEmpty();
|
||||
|
||||
var prompt2 = new Prompt(toolCallConversation, prompt.getOptions());
|
||||
|
||||
return processToolCall(prompt2, finishReasons, customFunction);
|
||||
}
|
||||
|
||||
return Flux.just(chatResponse);
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void functionCall2() throws JsonMappingException, JsonProcessingException {
|
||||
|
||||
List<Message> messages = List
|
||||
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
|
||||
|
||||
var prompt = new Prompt(messages, promptOptions);
|
||||
|
||||
ChatResponse chatResponse = this.functionCallingHelper.processCall(this.chatModel, prompt,
|
||||
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
|
||||
OpenAiApi.ChatCompletionFinishReason.STOP.name()),
|
||||
toolCall -> {
|
||||
|
||||
var functionName = toolCall.name();
|
||||
|
||||
assertThat(functionName).isEqualTo("getWeatherInLocation");
|
||||
|
||||
String functionArguments = toolCall.arguments();
|
||||
|
||||
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
|
||||
|
||||
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
|
||||
argumentsMap.get("unit").toString());
|
||||
|
||||
return functionResponse;
|
||||
});
|
||||
|
||||
logger.info("Response: {}", chatResponse);
|
||||
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void functionStream2() throws JsonMappingException, JsonProcessingException {
|
||||
|
||||
List<Message> messages = List
|
||||
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder().functionCallbacks(List.of(this.functionDefinition)).build();
|
||||
|
||||
var prompt = new Prompt(messages, promptOptions);
|
||||
|
||||
Flux<ChatResponse> responses = this.functionCallingHelper.processStream(this.chatModel, prompt,
|
||||
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
|
||||
OpenAiApi.ChatCompletionFinishReason.STOP.name()),
|
||||
toolCall -> {
|
||||
|
||||
var functionName = toolCall.name();
|
||||
|
||||
assertThat(functionName).isEqualTo("getWeatherInLocation");
|
||||
|
||||
String functionArguments = toolCall.arguments();
|
||||
|
||||
Map<String, String> argumentsMap = getFunctionArguments(functionArguments);
|
||||
|
||||
String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
|
||||
argumentsMap.get("unit").toString());
|
||||
|
||||
return functionResponse;
|
||||
});
|
||||
|
||||
String response = responses.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(cr -> cr.getResult().getOutput().getText())
|
||||
.collect(Collectors.joining());
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response).contains("30", "10", "15");
|
||||
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, List<FunctionCallback> toolFunctionCallbacks) {
|
||||
// enable the proxy tool calls option.
|
||||
var options = OpenAiChatOptions.builder().model(DEFAULT_MODEL).proxyToolCalls(true).build();
|
||||
|
||||
return new OpenAiChatModel(openAiApi, options, null, toolFunctionCallbacks,
|
||||
RetryUtils.DEFAULT_RETRY_TEMPLATE, ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -234,12 +234,12 @@ public class OpenAiChatModelResponseFormatIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
import org.hamcrest.core.StringContains;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
@@ -133,7 +134,7 @@ public class OpenAiChatModelWithChatResponseMetadataTests {
|
||||
httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_REMAINING_HEADER.getName(), "112358");
|
||||
httpHeaders.set(OpenAiApiResponseHeaders.TOKENS_RESET_HEADER.getName(), "27h55s451ms");
|
||||
|
||||
this.server.expect(requestTo("/v1/chat/completions"))
|
||||
this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions")))
|
||||
.andExpect(method(HttpMethod.POST))
|
||||
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
|
||||
.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));
|
||||
@@ -169,12 +170,16 @@ public class OpenAiChatModelWithChatResponseMetadataTests {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi(RestClient.Builder builder, WebClient.Builder webClientBuilder) {
|
||||
return new OpenAiApi("", TEST_API_KEY, builder, webClientBuilder);
|
||||
return OpenAiApi.builder()
|
||||
.apiKey(TEST_API_KEY)
|
||||
.restClientBuilder(builder)
|
||||
.webClientBuilder(webClientBuilder)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -55,7 +55,10 @@ public class OpenAiCompatibleChatModelIT {
|
||||
static Stream<ChatModel> openAiCompatibleApis() {
|
||||
Stream.Builder<ChatModel> builder = Stream.builder();
|
||||
|
||||
builder.add(new OpenAiChatModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")), forModelName("gpt-3.5-turbo")));
|
||||
builder.add(OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build())
|
||||
.defaultOptions(forModelName("gpt-3.5-turbo"))
|
||||
.build());
|
||||
|
||||
// (26.01.2025) Disable because the Groq API is down. TODO: Re-enable when the API
|
||||
// is back up.
|
||||
@@ -66,9 +69,13 @@ public class OpenAiCompatibleChatModelIT {
|
||||
// }
|
||||
|
||||
if (System.getenv("OPEN_ROUTER_API_KEY") != null) {
|
||||
builder.add(new OpenAiChatModel(
|
||||
new OpenAiApi("https://openrouter.ai/api", System.getenv("OPEN_ROUTER_API_KEY")),
|
||||
forModelName("meta-llama/llama-3-8b-instruct")));
|
||||
builder.add(OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl("https://openrouter.ai/api")
|
||||
.apiKey(System.getenv("OPEN_ROUTER_API_KEY"))
|
||||
.build())
|
||||
.defaultOptions(forModelName("meta-llama/llama-3-8b-instruct"))
|
||||
.build());
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
|
||||
@@ -16,11 +16,13 @@
|
||||
|
||||
package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
@@ -34,19 +36,28 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
|
||||
import org.springframework.ai.tool.execution.ToolExecutionExceptionProcessor;
|
||||
import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver;
|
||||
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
|
||||
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
|
||||
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
|
||||
import org.springframework.beans.factory.ObjectProvider;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Description;
|
||||
import org.springframework.context.support.GenericApplicationContext;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
@@ -71,7 +82,7 @@ public class OpenAiPaymentTransactionIT {
|
||||
public void transactionPaymentStatuses(String functionName) {
|
||||
List<TransactionStatusResponse> content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.functions(functionName)
|
||||
.tools(functionName)
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
""")
|
||||
@@ -102,7 +113,7 @@ public class OpenAiPaymentTransactionIT {
|
||||
|
||||
Flux<String> flux = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.functions(functionName)
|
||||
.tools(functionName)
|
||||
.user(u -> u.text("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
|
||||
@@ -217,25 +228,56 @@ public class OpenAiPaymentTransactionIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, FunctionCallbackResolver functionCallbackResolver) {
|
||||
return new OpenAiChatModel(openAiApi,
|
||||
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build(),
|
||||
functionCallbackResolver, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi, ToolCallingManager toolCallingManager) {
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(openAiApi)
|
||||
.toolCallingManager(toolCallingManager)
|
||||
.defaultOptions(
|
||||
OpenAiChatOptions.builder().model(ChatModel.GPT_4_O_MINI.getName()).temperature(0.1).build())
|
||||
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
|
||||
.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance must
|
||||
* different from the other JSON schema types.
|
||||
*/
|
||||
@Bean
|
||||
public FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
|
||||
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
|
||||
manager.setApplicationContext(context);
|
||||
return manager;
|
||||
@ConditionalOnMissingBean
|
||||
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
|
||||
List<FunctionCallback> functionCallbacks, List<ToolCallbackProvider> tcbProviders) {
|
||||
|
||||
List<FunctionCallback> allFunctionAndToolCallbacks = new ArrayList<>(functionCallbacks);
|
||||
tcbProviders.stream()
|
||||
.map(pr -> List.of(pr.getToolCallbacks()))
|
||||
.forEach(allFunctionAndToolCallbacks::addAll);
|
||||
|
||||
var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks);
|
||||
|
||||
var springBeanToolCallbackResolver = SpringBeanToolCallbackResolver.builder()
|
||||
.applicationContext(applicationContext)
|
||||
.build();
|
||||
|
||||
return new DelegatingToolCallbackResolver(
|
||||
List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() {
|
||||
return new DefaultToolExecutionExceptionProcessor(false);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
ToolCallingManager toolCallingManager(ToolCallbackResolver toolCallbackResolver,
|
||||
ToolExecutionExceptionProcessor toolExecutionExceptionProcessor,
|
||||
ObjectProvider<ObservationRegistry> observationRegistry) {
|
||||
return ToolCallingManager.builder()
|
||||
.observationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
|
||||
.toolCallbackResolver(toolCallbackResolver)
|
||||
.toolExecutionExceptionProcessor(toolExecutionExceptionProcessor)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -110,8 +110,11 @@ public class OpenAiRetryTests {
|
||||
this.retryListener = new TestRetryListener();
|
||||
this.retryTemplate.registerListener(this.retryListener);
|
||||
|
||||
this.chatModel = new OpenAiChatModel(this.openAiApi, OpenAiChatOptions.builder().build(), null,
|
||||
this.retryTemplate);
|
||||
this.chatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(this.openAiApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().build())
|
||||
.retryTemplate(this.retryTemplate)
|
||||
.build();
|
||||
this.embeddingModel = new OpenAiEmbeddingModel(this.openAiApi, MetadataMode.EMBED,
|
||||
OpenAiEmbeddingOptions.builder().build(), this.retryTemplate);
|
||||
this.audioTranscriptionModel = new OpenAiAudioTranscriptionModel(this.openAiAudioApi,
|
||||
|
||||
@@ -330,12 +330,15 @@ class DeepSeekWithOpenAiChatModelIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(DEEPSEEK_BASE_URL, System.getenv("DEEPSEEK_API_KEY"));
|
||||
return OpenAiApi.builder().baseUrl(DEEPSEEK_BASE_URL).apiKey(System.getenv("DEEPSEEK_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build());
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(openAiApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_DEEPSEEK_MODEL).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -383,12 +383,15 @@ class GroqWithOpenAiChatModelIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(GROQ_BASE_URL, System.getenv("GROQ_API_KEY"));
|
||||
return OpenAiApi.builder().baseUrl(GROQ_BASE_URL).apiKey(System.getenv("GROQ_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build());
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(openAiApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_GROQ_MODEL).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -318,13 +318,15 @@ class NvidiaWithOpenAiChatModelIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(NVIDIA_BASE_URL, System.getenv("NVIDIA_API_KEY"));
|
||||
return OpenAiApi.builder().baseUrl(NVIDIA_BASE_URL).apiKey(System.getenv("NVIDIA_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi,
|
||||
OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build());
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(openAiApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_NVIDIA_MODEL).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -327,14 +327,20 @@ class PerplexityWithOpenAiChatModelIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi chatCompletionApi() {
|
||||
return new OpenAiApi(PERPLEXITY_BASE_URL, System.getenv("PERPLEXITY_API_KEY"), PERPLEXITY_COMPLETIONS_PATH,
|
||||
"/v1/embeddings", RestClient.builder(), WebClient.builder(),
|
||||
RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
return OpenAiApi.builder()
|
||||
.baseUrl(PERPLEXITY_BASE_URL)
|
||||
.apiKey(System.getenv("PERPLEXITY_API_KEY"))
|
||||
.completionsPath(PERPLEXITY_COMPLETIONS_PATH)
|
||||
.embeddingsPath("/v1/embeddings")
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build());
|
||||
return OpenAiChatModel.builder()
|
||||
.openAiApi(openAiApi)
|
||||
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_PERPLEXITY_MODEL).build())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ public class OpenAiEmbeddingModelObservationIT {
|
||||
|
||||
@Bean
|
||||
public OpenAiApi openAiApi() {
|
||||
return new OpenAiApi(System.getenv("OPENAI_API_KEY"));
|
||||
return OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -163,13 +163,12 @@ public class MetadataTransformerIT {
|
||||
throw new IllegalArgumentException(
|
||||
"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
|
||||
}
|
||||
return new OpenAiApi(apiKey);
|
||||
return OpenAiApi.builder().apiKey(apiKey).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatModel openAiChatModel(OpenAiApi openAiApi) {
|
||||
OpenAiChatModel openAiChatModel = new OpenAiChatModel(openAiApi);
|
||||
return openAiChatModel;
|
||||
return OpenAiChatModel.builder().openAiApi(openAiApi).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.vertexai.gemini.tool;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.google.cloud.vertexai.Transport;
|
||||
import com.google.cloud.vertexai.VertexAI;
|
||||
import org.junit.jupiter.api.RepeatedTest;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
|
||||
import org.springframework.ai.model.function.FunctionCallback.SchemaType;
|
||||
import org.springframework.ai.model.function.FunctionCallbackResolver;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Description;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_LOCATION", matches = ".*")
|
||||
@Deprecated
|
||||
public class VertexAiGeminiPaymentTransactionDeprecatedIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiPaymentTransactionDeprecatedIT.class);
|
||||
|
||||
private static final Map<Transaction, Status> DATASET = Map.of(new Transaction("001"), new Status("pending"),
|
||||
new Transaction("002"), new Status("approved"), new Transaction("003"), new Status("rejected"));
|
||||
|
||||
@Autowired
|
||||
ChatClient chatClient;
|
||||
|
||||
@Test
|
||||
public void paymentStatuses() {
|
||||
// @formatter:off
|
||||
String content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.tools("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
""").call().content();
|
||||
|
||||
logger.info("" + content);
|
||||
|
||||
assertThat(content).contains("001", "002", "003");
|
||||
assertThat(content).contains("pending", "approved", "rejected");
|
||||
}
|
||||
|
||||
@RepeatedTest(5)
|
||||
public void streamingPaymentStatuses() {
|
||||
|
||||
Flux<String> streamContent = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.tools("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
""")
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
String content = streamContent.collectList().block().stream().collect(Collectors.joining());
|
||||
|
||||
logger.info(content);
|
||||
|
||||
assertThat(content).contains("001", "002", "003");
|
||||
assertThat(content).contains("pending", "approved", "rejected");
|
||||
|
||||
// Quota rate
|
||||
try {
|
||||
Thread.sleep(1000);
|
||||
}
|
||||
catch (InterruptedException e) {
|
||||
}
|
||||
}
|
||||
|
||||
record TransactionStatusResponse(String id, String status) {
|
||||
|
||||
}
|
||||
|
||||
private static class LoggingAdvisor implements CallAroundAdvisor {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
var response = chain.nextAroundCall(before(advisedRequest));
|
||||
observeAfter(response);
|
||||
return response;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.info("System text: \n" + request.systemText());
|
||||
logger.info("System params: " + request.systemParams());
|
||||
logger.info("User text: \n" + request.userText());
|
||||
logger.info("User params:" + request.userParams());
|
||||
logger.info("Function names: " + request.functionNames());
|
||||
|
||||
logger.info("Options: " + request.chatOptions().toString());
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.info("Response: " + advisedResponse.response());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record Transaction(String id) {
|
||||
}
|
||||
|
||||
record Status(String name) {
|
||||
}
|
||||
|
||||
record Transactions(List<Transaction> transactions) {
|
||||
}
|
||||
|
||||
record Statuses(List<Status> statuses) {
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
public static class TestConfiguration {
|
||||
|
||||
@Bean
|
||||
@Description("Get the status of a single payment transaction")
|
||||
public Function<Transaction, Status> paymentStatus() {
|
||||
return transaction -> {
|
||||
logger.info("Single Transaction: " + transaction);
|
||||
return DATASET.get(transaction);
|
||||
};
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Description("Get the list statuses of a list of payment transactions")
|
||||
public Function<Transactions, Statuses> paymentStatuses() {
|
||||
return transactions -> {
|
||||
logger.info("Transactions: " + transactions);
|
||||
return new Statuses(transactions.transactions().stream().map(t -> DATASET.get(t)).toList());
|
||||
};
|
||||
}
|
||||
|
||||
@Bean
|
||||
public ChatClient chatClient(VertexAiGeminiChatModel chatModel) {
|
||||
return ChatClient.builder(chatModel).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public VertexAI vertexAiApi() {
|
||||
|
||||
String projectId = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID");
|
||||
String location = System.getenv("VERTEX_AI_GEMINI_LOCATION");
|
||||
|
||||
return new VertexAI.Builder().setLocation(location)
|
||||
.setProjectId(projectId)
|
||||
.setTransport(Transport.REST)
|
||||
// .setTransport(Transport.GRPC)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public VertexAiGeminiChatModel vertexAiChatModel(VertexAI vertexAi, ApplicationContext context) {
|
||||
|
||||
FunctionCallbackResolver functionCallbackResolver = springAiFunctionManager(context);
|
||||
|
||||
return new VertexAiGeminiChatModel(vertexAi,
|
||||
VertexAiGeminiChatOptions.builder()
|
||||
.model(VertexAiGeminiChatModel.ChatModel.GEMINI_1_5_FLASH)
|
||||
.temperature(0.1)
|
||||
.build(),
|
||||
functionCallbackResolver);
|
||||
}
|
||||
|
||||
/**
|
||||
* Because of the OPEN_API_SCHEMA type, the FunctionCallbackResolver instance
|
||||
* must
|
||||
* different from the other JSON schema types.
|
||||
*/
|
||||
private FunctionCallbackResolver springAiFunctionManager(ApplicationContext context) {
|
||||
DefaultFunctionCallbackResolver manager = new DefaultFunctionCallbackResolver();
|
||||
manager.setSchemaType(SchemaType.OPEN_API_SCHEMA);
|
||||
manager.setApplicationContext(context);
|
||||
return manager;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -30,19 +30,13 @@ import com.fasterxml.jackson.annotation.JsonPropertyOrder;
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "generationTokens", "nativeUsage" })
|
||||
@JsonPropertyOrder({ "promptTokens", "completionTokens", "totalTokens", "nativeUsage" })
|
||||
public class DefaultUsage implements Usage {
|
||||
|
||||
private final Integer promptTokens;
|
||||
|
||||
private final Integer completionTokens;
|
||||
|
||||
/**
|
||||
* @deprecated as of 1.0.0-M6, scheduled for removal
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
private final Long generationTokens;
|
||||
|
||||
private final int totalTokens;
|
||||
|
||||
private final Object nativeUsage;
|
||||
@@ -62,7 +56,6 @@ public class DefaultUsage implements Usage {
|
||||
public DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens, Object nativeUsage) {
|
||||
this.promptTokens = promptTokens != null ? promptTokens : 0;
|
||||
this.completionTokens = completionTokens != null ? completionTokens : 0;
|
||||
this.generationTokens = Long.valueOf(this.completionTokens);
|
||||
this.totalTokens = totalTokens != null ? totalTokens
|
||||
: calculateTotalTokens(this.promptTokens, this.completionTokens);
|
||||
this.nativeUsage = nativeUsage;
|
||||
@@ -106,11 +99,8 @@ public class DefaultUsage implements Usage {
|
||||
@JsonCreator
|
||||
public static DefaultUsage fromJson(@JsonProperty("promptTokens") Integer promptTokens,
|
||||
@JsonProperty("completionTokens") Integer completionTokens,
|
||||
@JsonProperty("generationTokens") Long generationTokens, @JsonProperty("totalTokens") Integer totalTokens,
|
||||
@JsonProperty("nativeUsage") Object nativeUsage) {
|
||||
Integer effectiveCompletionTokens = completionTokens != null ? completionTokens
|
||||
: (generationTokens != null ? generationTokens.intValue() : 0);
|
||||
return new DefaultUsage(promptTokens, effectiveCompletionTokens, totalTokens, nativeUsage);
|
||||
@JsonProperty("totalTokens") Integer totalTokens, @JsonProperty("nativeUsage") Object nativeUsage) {
|
||||
return new DefaultUsage(promptTokens, completionTokens, totalTokens, nativeUsage);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -34,11 +34,6 @@ public interface Usage {
|
||||
*/
|
||||
Integer getPromptTokens();
|
||||
|
||||
@Deprecated(forRemoval = true, since = "1.0.0-M6")
|
||||
default Long getGenerationTokens() {
|
||||
return getCompletionTokens().longValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the number of tokens returned in the {@literal generation (aka completion)}
|
||||
* of the AI's response.
|
||||
|
||||
@@ -32,13 +32,12 @@ public class DefaultUsageTests {
|
||||
void testSerializationWithAllFields() throws Exception {
|
||||
DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json)
|
||||
.isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}");
|
||||
assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDeserializationWithAllFields() throws Exception {
|
||||
String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}";
|
||||
String json = "{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}";
|
||||
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
|
||||
assertThat(usage.getPromptTokens()).isEqualTo(100);
|
||||
assertThat(usage.getCompletionTokens()).isEqualTo(50);
|
||||
@@ -49,8 +48,7 @@ public class DefaultUsageTests {
|
||||
void testSerializationWithNullFields() throws Exception {
|
||||
DefaultUsage usage = new DefaultUsage((Integer) null, (Integer) null, (Integer) null);
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json)
|
||||
.isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}");
|
||||
assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -92,8 +90,7 @@ public class DefaultUsageTests {
|
||||
|
||||
// Test serialization
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json)
|
||||
.isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50}");
|
||||
assertThat(json).isEqualTo("{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150}");
|
||||
|
||||
// Test deserialization
|
||||
DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class);
|
||||
@@ -113,8 +110,7 @@ public class DefaultUsageTests {
|
||||
|
||||
// Test serialization
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json)
|
||||
.isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0,\"generationTokens\":0}");
|
||||
assertThat(json).isEqualTo("{\"promptTokens\":0,\"completionTokens\":0,\"totalTokens\":0}");
|
||||
|
||||
// Test deserialization
|
||||
DefaultUsage deserializedUsage = this.objectMapper.readValue(json, DefaultUsage.class);
|
||||
@@ -123,18 +119,9 @@ public class DefaultUsageTests {
|
||||
assertThat(deserializedUsage.getTotalTokens()).isEqualTo(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDeserializationWithLegacyFormat() throws Exception {
|
||||
String json = "{\"promptTokens\":100,\"generationTokens\":50,\"totalTokens\":150}";
|
||||
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
|
||||
assertThat(usage.getPromptTokens()).isEqualTo(100);
|
||||
assertThat(usage.getCompletionTokens()).isEqualTo(50);
|
||||
assertThat(usage.getTotalTokens()).isEqualTo(150);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDeserializationWithDifferentPropertyOrder() throws Exception {
|
||||
String json = "{\"totalTokens\":150,\"generationTokens\":50,\"completionTokens\":50,\"promptTokens\":100}";
|
||||
String json = "{\"totalTokens\":150,\"completionTokens\":50,\"promptTokens\":100}";
|
||||
DefaultUsage usage = this.objectMapper.readValue(json, DefaultUsage.class);
|
||||
assertThat(usage.getPromptTokens()).isEqualTo(100);
|
||||
assertThat(usage.getCompletionTokens()).isEqualTo(50);
|
||||
@@ -150,7 +137,7 @@ public class DefaultUsageTests {
|
||||
DefaultUsage usage = new DefaultUsage(100, 50, 150, customNativeUsage);
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json).isEqualTo(
|
||||
"{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"generationTokens\":50,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}");
|
||||
"{\"promptTokens\":100,\"completionTokens\":50,\"totalTokens\":150,\"nativeUsage\":{\"custom_field\":\"custom_value\",\"custom_number\":42}}");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -195,14 +182,6 @@ public class DefaultUsageTests {
|
||||
assertThat(deserializedMap.get("field5")).isEqualTo(java.util.Map.of("nested", "value"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@SuppressWarnings("deprecation")
|
||||
void testDeprecatedGenerationTokens() {
|
||||
DefaultUsage usage = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
|
||||
assertThat(usage.getGenerationTokens()).isEqualTo(50L);
|
||||
assertThat(usage.getCompletionTokens().longValue()).isEqualTo(usage.getGenerationTokens());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testEqualsAndHashCode() {
|
||||
DefaultUsage usage1 = new DefaultUsage(Integer.valueOf(100), Integer.valueOf(50), Integer.valueOf(150));
|
||||
@@ -262,8 +241,7 @@ public class DefaultUsageTests {
|
||||
assertThat(usage.getTotalTokens()).isEqualTo(-3);
|
||||
|
||||
String json = this.objectMapper.writeValueAsString(usage);
|
||||
assertThat(json)
|
||||
.isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3,\"generationTokens\":-2}");
|
||||
assertThat(json).isEqualTo("{\"promptTokens\":-1,\"completionTokens\":-2,\"totalTokens\":-3}");
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -73,18 +73,6 @@ public class ToolCallingManagerTests {
|
||||
runExplicitToolCallingExecutionWithOptions(chatOptions, prompt);
|
||||
}
|
||||
|
||||
@Test
|
||||
void explicitToolCallingExecutionWithLegacyOptions() {
|
||||
ChatOptions chatOptions = FunctionCallingOptions.builder()
|
||||
.functionCallbacks(ToolCallbacks.from(tools))
|
||||
.proxyToolCalls(true)
|
||||
.build();
|
||||
Prompt prompt = new Prompt(
|
||||
new UserMessage("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")),
|
||||
chatOptions);
|
||||
runExplicitToolCallingExecutionWithOptions(chatOptions, prompt);
|
||||
}
|
||||
|
||||
@Test
|
||||
void explicitToolCallingExecutionWithNewOptionsStream() {
|
||||
ChatOptions chatOptions = ToolCallingChatOptions.builder()
|
||||
|
||||
@@ -35,7 +35,7 @@ public class MistralAiChatProperties extends MistralAiParentProperties {
|
||||
|
||||
public static final String CONFIG_PREFIX = "spring.ai.mistralai.chat";
|
||||
|
||||
public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue();
|
||||
public static final String DEFAULT_CHAT_MODEL = MistralAiApi.ChatModel.SMALL.getValue();
|
||||
|
||||
private static final Double DEFAULT_TEMPERATURE = 0.7;
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.minimax.MiniMaxChatModel;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
@@ -97,11 +98,9 @@ class FunctionCallbackWithPlainFunctionBeanIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.function("weatherFunction")
|
||||
.build();
|
||||
ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
});
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.moonshot.MoonshotChatModel;
|
||||
import org.springframework.ai.moonshot.MoonshotChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
@@ -98,11 +99,9 @@ class FunctionCallbackWithPlainFunctionBeanIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.function("weatherFunction")
|
||||
.build();
|
||||
ToolCallingChatOptions toolOptions = ToolCallingChatOptions.builder().toolNames("weatherFunction").build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), toolOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
});
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
package org.springframework.ai.autoconfigure.vertexai.gemini.tool;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -29,6 +30,7 @@ import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatModel;
|
||||
import org.springframework.ai.vertexai.gemini.VertexAiGeminiChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
@@ -109,14 +111,14 @@ class FunctionCallWithFunctionBeanIT {
|
||||
""");
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
|
||||
FunctionCallingOptions.builder().function("weatherFunction").build()));
|
||||
ToolCallingChatOptions.builder().toolNames("weatherFunction").build()));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
|
||||
|
||||
response = chatModel.call(new Prompt(List.of(userMessage),
|
||||
VertexAiGeminiChatOptions.builder().function("weatherFunction3").build()));
|
||||
VertexAiGeminiChatOptions.builder().toolNames(Set.of("weatherFunction3")).build()));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
@@ -97,8 +98,8 @@ class FunctionCallbackWithPlainFunctionBeanIT {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
|
||||
|
||||
FunctionCallingOptions functionOptions = FunctionCallingOptions.builder()
|
||||
.function("weatherFunction")
|
||||
ToolCallingChatOptions functionOptions = ToolCallingChatOptions.builder()
|
||||
.toolNames("weatherFunction")
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), functionOptions));
|
||||
|
||||
@@ -100,7 +100,7 @@ class MongoDbAtlasLocalContainerConnectionDetailsFactoryIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ public class BasicAuthChromaWhereIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ public class ChromaVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ public class ChromaVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ public class TokenSecuredChromaWhereIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -505,7 +505,7 @@ class ElasticsearchVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -220,7 +220,7 @@ public class ElasticsearchVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -128,7 +128,7 @@ public class HanaCloudVectorStoreIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -206,7 +206,7 @@ public class HanaVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ public class MariaDBStoreCustomNamesIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -451,7 +451,7 @@ public class MariaDBStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -203,7 +203,7 @@ public class MariaDBStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ class MilvusVectorStoreCustomFieldNamesIT {
|
||||
|
||||
@Bean
|
||||
EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -361,7 +361,7 @@ public class MilvusVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
// return new OpenAiEmbeddingModel(new
|
||||
// OpenAiApi(System.getenv("OPENAI_API_KEY")), MetadataMode.EMBED,
|
||||
// OpenAiEmbeddingOptions.builder().withModel("text-embedding-ada-002").build());
|
||||
|
||||
@@ -190,7 +190,7 @@ public class MilvusVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@ class MongoDBAtlasVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -209,7 +209,7 @@ public class MongoDbVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -374,7 +374,7 @@ class Neo4jVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ public class Neo4jVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -602,7 +602,7 @@ class OpenSearchVectorStoreIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ public class OpenSearchVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -245,7 +245,7 @@ public class PgVectorStoreCustomNamesIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -483,7 +483,7 @@ public class PgVectorStoreIT extends BaseVectorStoreTests {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ public class PgVectorStoreObservationIT {
|
||||
|
||||
@Bean
|
||||
public EmbeddingModel embeddingModel() {
|
||||
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
||||
return new OpenAiEmbeddingModel(OpenAiApi.builder().apiKey(System.getenv("OPENAI_API_KEY")).build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user