ChatClient Fluent DSL and renaming [Model/Chat/Embedding/Image/Speech]Client to Model, [Chat/Embedding/Image/Speech]Model
- copy constructor for ChatClient#call incantations based on defaults. - fix ChatClient structured output flow. - add OpenAiChatClientIT. - add ChatClient stream support. - implement fromOptions copy factory for every ChatOptions implementation. - extend ChatClient to use the caller default options if not provided explicitly. - fix system/user text overriding default system/user texts. Only non-empty user/system text can override default system/user text. - rename chat() method to collect(). - add OpenAI FunctionCallbackWrapper2IT auto-config tests. - rename request call() to prompt() and the response collect() to call(). Adjust tests. - add overload methods for defaultSystem()/defaultUser() and system() methods. - add ChatClientTest Mockito testing. - fix ChatClient list() conversions. - Rename the ModelClient class hierarchy into Model. - Rename ModelClient into Model. Update all code and doc references. - Rename ChatClient to ChatModel. Update all ChatClient suffixes and chatClient fields and variables in code and doc. - Rename EmbeddingClient into EmbeddingModel. Update the XxxEmbeddingClient class and variable suffixes and embeddingClient variables and fields in code and docs. - Rename ImageClient into ImageModel. ... - Rename SpeechClient into SpeechModel. ... - Rename TranscriptionClient into TranscriptionModel. ... - Update all javadocs and antora pages. Update the related diagrams. - refactor usability tests. Co-authored-by: Christian Tzolov <ctzolov@vmware.com>
This commit is contained in:
committed by
Christian Tzolov
parent
bce45c2d2f
commit
57615b6303
@@ -26,6 +26,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi;
|
||||
@@ -38,10 +39,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi.StreamResponse;
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
|
||||
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
@@ -56,16 +56,16 @@ import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* The {@link ChatClient} implementation for the Anthropic service.
|
||||
* The {@link ChatModel} implementation for the Anthropic service.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class AnthropicChatClient extends
|
||||
public class AnthropicChatModel extends
|
||||
AbstractFunctionCallSupport<AnthropicApi.RequestMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletion>>
|
||||
implements ChatClient, StreamingChatClient {
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClient.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
|
||||
|
||||
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue();
|
||||
|
||||
@@ -89,10 +89,10 @@ public class AnthropicChatClient extends
|
||||
public final RetryTemplate retryTemplate;
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatClient} instance.
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
*/
|
||||
public AnthropicChatClient(AnthropicApi anthropicApi) {
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi) {
|
||||
this(anthropicApi,
|
||||
AnthropicChatOptions.builder()
|
||||
.withModel(DEFAULT_MODEL_NAME)
|
||||
@@ -102,34 +102,34 @@ public class AnthropicChatClient extends
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatClient} instance.
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
*/
|
||||
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
|
||||
this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatClient} instance.
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
*/
|
||||
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate) {
|
||||
this(anthropicApi, defaultOptions, retryTemplate, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a new {@link AnthropicChatClient} instance.
|
||||
* Construct a new {@link AnthropicChatModel} instance.
|
||||
* @param anthropicApi the lower-level API for the Anthropic service.
|
||||
* @param defaultOptions the default options used for the chat completion requests.
|
||||
* @param retryTemplate the retry template used to retry the Anthropic API calls.
|
||||
* @param functionCallbackContext the function callback context used to store the
|
||||
* state of the function calls.
|
||||
*/
|
||||
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
|
||||
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext) {
|
||||
|
||||
super(functionCallbackContext);
|
||||
@@ -457,4 +457,9 @@ public class AnthropicChatClient extends
|
||||
"Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version.");
|
||||
}
|
||||
|
||||
/**
 * Returns the default chat options of this model.
 * <p>
 * The value handed back is whatever
 * {@code AnthropicChatOptions.fromOptions(this.defaultOptions)} produces, not the
 * {@code defaultOptions} field itself.
 * @return the options built from this model's {@code defaultOptions}.
 */
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return AnthropicChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -51,11 +51,11 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
|
||||
private @JsonProperty("top_k") Integer topK;
|
||||
|
||||
/**
|
||||
* Tool Function Callbacks to register with the ChatClient. For Prompt
|
||||
* Tool Function Callbacks to register with the ChatModel. For Prompt
|
||||
* Options the functionCallbacks are automatically enabled for the duration of the
|
||||
* prompt execution. For Default Options the functionCallbacks are registered but
|
||||
* disabled by default. Use the enableFunctions to set the functions from the registry
|
||||
* to be used by the ChatClient chat completion requests.
|
||||
* to be used by the ChatModel chat completion requests.
|
||||
*/
|
||||
@NestedConfigurationProperty
|
||||
@JsonIgnore
|
||||
@@ -223,4 +223,17 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
|
||||
this.functions = functions;
|
||||
}
|
||||
|
||||
/**
 * Builds a new {@link AnthropicChatOptions} through {@code builder()}, carrying over
 * the model, max tokens, metadata, stop sequences, temperature, topP, topK, function
 * callbacks and functions read from the given instance.
 * <p>
 * NOTE(review): the values returned by the getters (including the collection-valued
 * ones) are passed to the builder as-is; whether the builder copies them is not
 * visible here - confirm before relying on the result being fully independent.
 * @param fromOptions the options to read from; dereferenced without a null check, so
 * it must not be {@code null}.
 * @return the options instance produced by the builder.
 */
public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
|
||||
return builder().withModel(fromOptions.getModel())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withMetadata(fromOptions.getMetadata())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
|
||||
.withFunctions(fromOptions.getFunctions())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -116,7 +117,7 @@ public class AnthropicApi {
|
||||
* "https://docs.anthropic.com/claude/docs/models-overview#model-comparison">model
|
||||
* comparison</a> for additional details and options.
|
||||
*/
|
||||
public enum ChatModel {
|
||||
public enum ChatModel implements ModelDescription {
|
||||
|
||||
// @formatter:off
|
||||
CLAUDE_3_OPUS("claude-3-opus-20240229"),
|
||||
@@ -140,6 +141,11 @@ public class AnthropicApi {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
/**
 * Returns the {@code value} string of this enum constant as its model name.
 * @return the {@code value} field of this constant.
 */
@Override
|
||||
public String getModelName() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -29,10 +29,10 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.anthropic.api.AnthropicApi;
|
||||
import org.springframework.ai.anthropic.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
@@ -56,15 +56,15 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429")
|
||||
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
|
||||
class AnthropicChatClientIT {
|
||||
class AnthropicChatModelIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClientIT.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModelIT.class);
|
||||
|
||||
@Autowired
|
||||
protected ChatClient chatClient;
|
||||
protected ChatModel chatModel;
|
||||
|
||||
@Autowired
|
||||
protected StreamingChatClient streamingChatClient;
|
||||
protected StreamingChatModel streamingChatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -76,7 +76,7 @@ class AnthropicChatClientIT {
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = chatClient.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
|
||||
assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
|
||||
@@ -102,7 +102,7 @@ class AnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatClient.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = listOutputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -120,7 +120,7 @@ class AnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = mapOutputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -142,7 +142,7 @@ class AnthropicChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
@@ -163,7 +163,7 @@ class AnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = streamingChatClient.stream(prompt)
|
||||
String generationTextFromStream = streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -187,7 +187,7 @@ class AnthropicChatClientIT {
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?",
|
||||
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
||||
|
||||
var response = chatClient.call(new Prompt(List.of(userMessage)));
|
||||
var response = chatModel.call(new Prompt(List.of(userMessage)));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
|
||||
@@ -209,7 +209,7 @@ class AnthropicChatClientIT {
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
|
||||
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
@@ -38,9 +38,9 @@ public class AnthropicTestConfiguration {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public AnthropicChatClient openAiChatClient(AnthropicApi api) {
|
||||
AnthropicChatClient anthropicChatClient = new AnthropicChatClient(api);
|
||||
return anthropicChatClient;
|
||||
public AnthropicChatModel openAiChatModel(AnthropicApi api) {
|
||||
AnthropicChatModel anthropicChatModel = new AnthropicChatModel(api);
|
||||
return anthropicChatModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ public class ChatCompletionRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new AnthropicChatClient(new AnthropicApi("TEST"),
|
||||
var client = new AnthropicChatModel(new AnthropicApi("TEST"),
|
||||
AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content"), false);
|
||||
|
||||
@@ -38,10 +38,10 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.PromptMetadata;
|
||||
@@ -63,7 +63,7 @@ import java.util.Set;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
|
||||
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
|
||||
* {@link OpenAIClient}.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
@@ -71,12 +71,12 @@ import java.util.concurrent.atomic.AtomicBoolean;
|
||||
* @author John Blum
|
||||
* @author Christian Tzolov
|
||||
* @author Grogdunn
|
||||
* @see ChatClient
|
||||
* @see ChatModel
|
||||
* @see com.azure.ai.openai.OpenAIClient
|
||||
*/
|
||||
public class AzureOpenAiChatClient
|
||||
public class AzureOpenAiChatModel
|
||||
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
|
||||
implements ChatClient, StreamingChatClient {
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";
|
||||
|
||||
@@ -94,7 +94,7 @@ public class AzureOpenAiChatClient
|
||||
*/
|
||||
private final OpenAIClient openAIClient;
|
||||
|
||||
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
|
||||
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
|
||||
this(microsoftOpenAiClient,
|
||||
AzureOpenAiChatOptions.builder()
|
||||
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
|
||||
@@ -102,11 +102,11 @@ public class AzureOpenAiChatClient
|
||||
.build());
|
||||
}
|
||||
|
||||
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
|
||||
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
|
||||
this(microsoftOpenAiClient, options, null);
|
||||
}
|
||||
|
||||
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
|
||||
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
|
||||
FunctionCallbackContext functionCallbackContext) {
|
||||
super(functionCallbackContext);
|
||||
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
|
||||
@@ -117,17 +117,17 @@ public class AzureOpenAiChatClient
|
||||
|
||||
/**
|
||||
* @deprecated since 0.8.0, use
|
||||
* {@link #AzureOpenAiChatClient(OpenAIClient, AzureOpenAiChatOptions)} instead.
|
||||
* {@link #AzureOpenAiChatModel(OpenAIClient, AzureOpenAiChatOptions)} instead.
|
||||
*/
|
||||
@Deprecated(forRemoval = true, since = "0.8.0")
|
||||
public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
|
||||
public AzureOpenAiChatModel withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
|
||||
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
|
||||
this.defaultOptions = defaultOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AzureOpenAiChatOptions getDefaultOptions() {
|
||||
return this.defaultOptions;
|
||||
return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -127,11 +127,11 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
|
||||
private String deploymentName;
|
||||
|
||||
/**
|
||||
* OpenAI Tool Function Callbacks to register with the ChatClient. For Prompt Options
|
||||
* OpenAI Tool Function Callbacks to register with the ChatModel. For Prompt Options
|
||||
* the functionCallbacks are automatically enabled for the duration of the prompt
|
||||
* execution. For Default Options the functionCallbacks are registered but disabled by
|
||||
* default. Use the enableFunctions to set the functions from the registry to be used
|
||||
* by the ChatClient chat completion requests.
|
||||
* by the ChatModel chat completion requests.
|
||||
*/
|
||||
@NestedConfigurationProperty
|
||||
@JsonIgnore
|
||||
@@ -356,4 +356,22 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
|
||||
this.functions = functions;
|
||||
}
|
||||
|
||||
/**
 * Builds a new {@link AzureOpenAiChatOptions} through {@code builder()}, carrying
 * over the deployment name, frequency penalty, logit bias, max tokens, n, presence
 * penalty, stop, temperature, topP, user, function callbacks and functions read from
 * the given instance.
 * <p>
 * The frequency and presence penalties are converted with {@code floatValue()} when
 * present and passed as {@code null} otherwise.
 * <p>
 * NOTE(review): the remaining getter values (including the collection-valued ones)
 * are passed to the builder as-is; whether the builder copies them is not visible
 * here - confirm before relying on the result being fully independent.
 * @param fromOptions the options to read from; dereferenced without a null check, so
 * it must not be {@code null}.
 * @return the options instance produced by the builder.
 */
public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) {
|
||||
return builder().withDeploymentName(fromOptions.getDeploymentName())
|
||||
.withFrequencyPenalty(
|
||||
fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty().floatValue() : null)
|
||||
.withLogitBias(fromOptions.getLogitBias())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withN(fromOptions.getN())
|
||||
.withPresencePenalty(
|
||||
fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty().floatValue() : null)
|
||||
.withStop(fromOptions.getStop())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withUser(fromOptions.getUser())
|
||||
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
|
||||
.withFunctions(fromOptions.getFunctions())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -37,9 +37,9 @@ import org.springframework.ai.embedding.EmbeddingResponseMetadata;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingClient.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);
|
||||
|
||||
private final OpenAIClient azureOpenAiClient;
|
||||
|
||||
@@ -47,16 +47,16 @@ public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
|
||||
private final MetadataMode metadataMode;
|
||||
|
||||
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {
|
||||
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) {
|
||||
this(azureOpenAiClient, MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
|
||||
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
|
||||
this(azureOpenAiClient, metadataMode,
|
||||
AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build());
|
||||
}
|
||||
|
||||
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
|
||||
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
|
||||
AzureOpenAiEmbeddingOptions options) {
|
||||
Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
|
||||
Assert.notNull(metadataMode, "Metadata mode must not be null");
|
||||
@@ -53,7 +53,7 @@ public class AzureChatCompletionsOptionsTests {
|
||||
.withUser("user")
|
||||
.build();
|
||||
|
||||
var client = new AzureOpenAiChatClient(mockClient, defaultOptions);
|
||||
var client = new AzureOpenAiChatModel(mockClient, defaultOptions);
|
||||
|
||||
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ public class AzureEmbeddingsOptionsTests {
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
|
||||
var client = new AzureOpenAiEmbeddingClient(mockClient, MetadataMode.EMBED,
|
||||
var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
|
||||
AzureOpenAiEmbeddingOptions.builder()
|
||||
.withDeploymentName("DEFAULT_MODEL")
|
||||
.withUser("USER_TEST")
|
||||
|
||||
@@ -46,13 +46,13 @@ import org.springframework.core.convert.support.DefaultConversionService;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = AzureOpenAiChatClientIT.TestConfiguration.class)
|
||||
@SpringBootTest(classes = AzureOpenAiChatModelIT.TestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
|
||||
class AzureOpenAiChatClientIT {
|
||||
class AzureOpenAiChatModelIT {
|
||||
|
||||
@Autowired
|
||||
private AzureOpenAiChatClient chatClient;
|
||||
private AzureOpenAiChatModel chatModel;
|
||||
|
||||
record ActorsFilms(String actor, List<String> movies) {
|
||||
}
|
||||
@@ -69,7 +69,7 @@ class AzureOpenAiChatClientIT {
|
||||
UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates.");
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = chatClient.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
|
||||
@@ -86,7 +86,7 @@ class AzureOpenAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -105,7 +105,7 @@ class AzureOpenAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -124,7 +124,7 @@ class AzureOpenAiChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isNotNull();
|
||||
@@ -145,7 +145,7 @@ class AzureOpenAiChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
System.out.println(actorsFilms);
|
||||
@@ -166,7 +166,7 @@ class AzureOpenAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = chatClient.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -194,8 +194,8 @@ class AzureOpenAiChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
|
||||
return new AzureOpenAiChatClient(openAIClient,
|
||||
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) {
|
||||
return new AzureOpenAiChatModel(openAIClient,
|
||||
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build());
|
||||
|
||||
}
|
||||
@@ -34,25 +34,25 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
|
||||
class AzureOpenAiEmbeddingClientIT {
|
||||
class AzureOpenAiEmbeddingModelIT {
|
||||
|
||||
@Autowired
|
||||
private AzureOpenAiEmbeddingClient embeddingClient;
|
||||
private AzureOpenAiEmbeddingModel embeddingModel;
|
||||
|
||||
@Test
|
||||
void singleEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
System.out.println(embeddingClient.dimensions());
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
|
||||
System.out.println(embeddingModel.dimensions());
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
|
||||
}
|
||||
|
||||
@Test
|
||||
void batchEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel
|
||||
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(2);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
@@ -60,7 +60,7 @@ class AzureOpenAiEmbeddingClientIT {
|
||||
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
|
||||
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
@@ -74,8 +74,8 @@ class AzureOpenAiEmbeddingClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public AzureOpenAiEmbeddingClient azureEmbeddingClient(OpenAIClient openAIClient) {
|
||||
return new AzureOpenAiEmbeddingClient(openAIClient);
|
||||
public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) {
|
||||
return new AzureOpenAiEmbeddingModel(openAIClient);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -59,8 +59,8 @@ public class MockAzureOpenAiTestConfiguration {
|
||||
}
|
||||
|
||||
@Bean
|
||||
AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient microsoftAzureOpenAiClient) {
|
||||
return new AzureOpenAiChatClient(microsoftAzureOpenAiClient);
|
||||
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) {
|
||||
return new AzureOpenAiChatModel(microsoftAzureOpenAiClient);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
@@ -46,18 +46,18 @@ import reactor.core.publisher.Flux;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = AzureOpenAiChatClientFunctionCallIT.TestConfiguration.class)
|
||||
@SpringBootTest(classes = AzureOpenAiChatModelFunctionCallIT.TestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
|
||||
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
|
||||
class AzureOpenAiChatClientFunctionCallIT {
|
||||
class AzureOpenAiChatModelFunctionCallIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatClientFunctionCallIT.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelFunctionCallIT.class);
|
||||
|
||||
@Autowired
|
||||
private String selectedModel;
|
||||
|
||||
@Autowired
|
||||
private AzureOpenAiChatClient chatClient;
|
||||
private AzureOpenAiChatModel chatModel;
|
||||
|
||||
@Test
|
||||
void functionCallTest() {
|
||||
@@ -75,7 +75,7 @@ class AzureOpenAiChatClientFunctionCallIT {
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
|
||||
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
@@ -99,7 +99,7 @@ class AzureOpenAiChatClientFunctionCallIT {
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = chatClient.stream(new Prompt(messages, promptOptions));
|
||||
Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, promptOptions));
|
||||
|
||||
final var counter = new AtomicInteger();
|
||||
String content = response.doOnEach(listSignal -> counter.getAndIncrement())
|
||||
@@ -129,8 +129,8 @@ class AzureOpenAiChatClientFunctionCallIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, String selectedModel) {
|
||||
return new AzureOpenAiChatClient(openAIClient,
|
||||
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) {
|
||||
return new AzureOpenAiChatModel(openAIClient,
|
||||
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import com.azure.ai.openai.models.ContentFilterResultsForChoice;
|
||||
import com.azure.ai.openai.models.ContentFilterSeverity;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
|
||||
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
|
||||
import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
@@ -55,7 +55,7 @@ import org.springframework.web.context.request.WebRequest;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Unit Tests for {@link AzureOpenAiChatClient} asserting AI metadata.
|
||||
* Unit Tests for {@link AzureOpenAiChatModel} asserting AI metadata.
|
||||
*
|
||||
* @author John Blum
|
||||
* @author Christian Tzolov
|
||||
@@ -63,12 +63,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
*/
|
||||
@SpringBootTest
|
||||
@ActiveProfiles("spring-ai-azure-openai-mocks")
|
||||
@ContextConfiguration(classes = AzureOpenAiChatClientMetadataTests.TestConfiguration.class)
|
||||
@ContextConfiguration(classes = AzureOpenAiChatModelMetadataTests.TestConfiguration.class)
|
||||
@SuppressWarnings("unused")
|
||||
class AzureOpenAiChatClientMetadataTests {
|
||||
class AzureOpenAiChatModelMetadataTests {
|
||||
|
||||
@Autowired
|
||||
private AzureOpenAiChatClient aiClient;
|
||||
private AzureOpenAiChatModel aiClient;
|
||||
|
||||
@Test
|
||||
void azureOpenAiMetadataCapturedDuringGeneration() {
|
||||
@@ -164,4 +164,14 @@ public class AnthropicChatOptions implements ChatOptions {
|
||||
this.anthropicVersion = anthropicVersion;
|
||||
}
|
||||
|
||||
public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
|
||||
return builder().withTemperature(fromOptions.getTemperature())
|
||||
.withMaxTokensToSample(fromOptions.getMaxTokensToSample())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.withAnthropicVersion(fromOptions.getAnthropicVersion())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ package org.springframework.ai.bedrock.anthropic;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
@@ -27,25 +27,25 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
|
||||
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
|
||||
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
|
||||
/**
|
||||
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat
|
||||
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat
|
||||
* generative.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClient {
|
||||
public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final AnthropicChatBedrockApi anthropicChatApi;
|
||||
|
||||
private final AnthropicChatOptions defaultOptions;
|
||||
|
||||
public BedrockAnthropicChatClient(AnthropicChatBedrockApi chatApi) {
|
||||
public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) {
|
||||
this(chatApi,
|
||||
AnthropicChatOptions.builder()
|
||||
.withTemperature(0.8f)
|
||||
@@ -55,7 +55,7 @@ public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClie
|
||||
.build());
|
||||
}
|
||||
|
||||
public BedrockAnthropicChatClient(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) {
|
||||
public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) {
|
||||
this.anthropicChatApi = chatApi;
|
||||
this.defaultOptions = options;
|
||||
}
|
||||
@@ -117,4 +117,9 @@ public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClie
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return AnthropicChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -29,6 +29,7 @@ import software.amazon.awssdk.regions.Region;
|
||||
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
|
||||
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
|
||||
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
@@ -225,7 +226,7 @@ public class AnthropicChatBedrockApi extends
|
||||
/**
|
||||
* Anthropic models version.
|
||||
*/
|
||||
public enum AnthropicChatModel {
|
||||
public enum AnthropicChatModel implements ModelDescription {
|
||||
/**
|
||||
* anthropic.claude-instant-v1
|
||||
*/
|
||||
@@ -251,6 +252,11 @@ public class AnthropicChatBedrockApi extends
|
||||
AnthropicChatModel(String value) {
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -163,4 +163,14 @@ public class Anthropic3ChatOptions implements ChatOptions {
|
||||
this.anthropicVersion = anthropicVersion;
|
||||
}
|
||||
|
||||
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
|
||||
return builder().withTemperature(fromOptions.getTemperature())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.withAnthropicVersion(fromOptions.getAnthropicVersion())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,17 +15,25 @@
|
||||
*/
|
||||
package org.springframework.ai.bedrock.anthropic3;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
@@ -34,29 +42,21 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat
|
||||
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat
|
||||
* generative.
|
||||
*
|
||||
* @author Ben Middleton
|
||||
* @author Christian Tzolov
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient {
|
||||
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final Anthropic3ChatBedrockApi anthropicChatApi;
|
||||
|
||||
private final Anthropic3ChatOptions defaultOptions;
|
||||
|
||||
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi) {
|
||||
public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) {
|
||||
this(chatApi,
|
||||
Anthropic3ChatOptions.builder()
|
||||
.withTemperature(0.8f)
|
||||
@@ -66,7 +66,7 @@ public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatCli
|
||||
.build());
|
||||
}
|
||||
|
||||
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
|
||||
public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
|
||||
this.anthropicChatApi = chatApi;
|
||||
this.defaultOptions = options;
|
||||
}
|
||||
@@ -187,4 +187,9 @@ public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatCli
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.An
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
|
||||
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse;
|
||||
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.util.Assert;
|
||||
import reactor.core.publisher.Flux;
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
@@ -436,7 +437,7 @@ public class Anthropic3ChatBedrockApi extends
|
||||
/**
|
||||
* Anthropic models version.
|
||||
*/
|
||||
public enum AnthropicChatModel {
|
||||
public enum AnthropicChatModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* anthropic.claude-instant-v1
|
||||
@@ -476,6 +477,11 @@ public class Anthropic3ChatBedrockApi extends
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -24,11 +24,11 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -39,17 +39,17 @@ import org.springframework.util.Assert;
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockCohereChatClient implements ChatClient, StreamingChatClient {
|
||||
public class BedrockCohereChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final CohereChatBedrockApi chatApi;
|
||||
|
||||
private final BedrockCohereChatOptions defaultOptions;
|
||||
|
||||
public BedrockCohereChatClient(CohereChatBedrockApi chatApi) {
|
||||
public BedrockCohereChatModel(CohereChatBedrockApi chatApi) {
|
||||
this(chatApi, BedrockCohereChatOptions.builder().build());
|
||||
}
|
||||
|
||||
public BedrockCohereChatClient(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
|
||||
public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
|
||||
Assert.notNull(chatApi, "CohereChatBedrockApi must not be null");
|
||||
Assert.notNull(options, "BedrockCohereChatOptions must not be null");
|
||||
|
||||
@@ -114,4 +114,9 @@ public class BedrockCohereChatClient implements ChatClient, StreamingChatClient
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return BedrockCohereChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -213,4 +213,17 @@ public class BedrockCohereChatOptions implements ChatOptions {
|
||||
this.truncate = truncate;
|
||||
}
|
||||
|
||||
public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) {
|
||||
return builder().withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.withReturnLikelihoods(fromOptions.getReturnLikelihoods())
|
||||
.withNumGenerations(fromOptions.getNumGenerations())
|
||||
.withLogitBias(fromOptions.getLogitBias())
|
||||
.withTruncate(fromOptions.getTruncate())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -31,14 +31,14 @@ import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* {@link org.springframework.ai.embedding.EmbeddingClient} implementation that uses the
|
||||
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
|
||||
* Bedrock Cohere Embedding API. Note: The invocation metrics are not exposed by AWS for
|
||||
* this API. If this change in the future we will add it as metadata.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private final CohereEmbeddingBedrockApi embeddingApi;
|
||||
|
||||
@@ -50,7 +50,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
|
||||
// private CohereEmbeddingRequest.Truncate truncate =
|
||||
// CohereEmbeddingRequest.Truncate.NONE;
|
||||
|
||||
public BedrockCohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi) {
|
||||
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi) {
|
||||
this(cohereEmbeddingBedrockApi,
|
||||
BedrockCohereEmbeddingOptions.builder()
|
||||
.withInputType(CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
|
||||
@@ -58,7 +58,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
|
||||
.build());
|
||||
}
|
||||
|
||||
public BedrockCohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
|
||||
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
|
||||
BedrockCohereEmbeddingOptions options) {
|
||||
Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null");
|
||||
Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null");
|
||||
@@ -71,7 +71,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
|
||||
// * @param inputType the input type to use.
|
||||
// * @return this client.
|
||||
// */
|
||||
// public BedrockCohereEmbeddingClient withInputType(CohereEmbeddingRequest.InputType
|
||||
// public BedrockCohereEmbeddingModel withInputType(CohereEmbeddingRequest.InputType
|
||||
// inputType) {
|
||||
// this.inputType = inputType;
|
||||
// return this;
|
||||
@@ -85,7 +85,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
|
||||
// * @param truncate the truncate option to use.
|
||||
// * @return this client.
|
||||
// */
|
||||
// public BedrockCohereEmbeddingClient withTruncate(CohereEmbeddingRequest.Truncate
|
||||
// public BedrockCohereEmbeddingModel withTruncate(CohereEmbeddingRequest.Truncate
|
||||
// truncate) {
|
||||
// this.truncate = truncate;
|
||||
// return this;
|
||||
@@ -30,6 +30,7 @@ import software.amazon.awssdk.regions.Region;
|
||||
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
|
||||
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
@@ -366,7 +367,7 @@ public class CohereChatBedrockApi extends
|
||||
/**
|
||||
* Cohere models version.
|
||||
*/
|
||||
public enum CohereChatModel {
|
||||
public enum CohereChatModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* cohere.command-light-text-v14
|
||||
@@ -390,6 +391,11 @@ public class CohereChatBedrockApi extends
|
||||
CohereChatModel(String value) {
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -19,7 +19,7 @@ package org.springframework.ai.bedrock.jurassic2;
|
||||
import org.springframework.ai.bedrock.MessageToPromptConverter;
|
||||
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
@@ -29,19 +29,18 @@ import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Java {@link ChatClient} for the Bedrock Jurassic2 chat generative model.
|
||||
* Java {@link ChatModel} for the Bedrock Jurassic2 chat generative model.
|
||||
*
|
||||
* @author Ahmed Yousri
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class BedrockAi21Jurassic2ChatClient implements ChatClient {
|
||||
public class BedrockAi21Jurassic2ChatModel implements ChatModel {
|
||||
|
||||
private final Ai21Jurassic2ChatBedrockApi chatApi;
|
||||
|
||||
private final BedrockAi21Jurassic2ChatOptions defaultOptions;
|
||||
|
||||
public BedrockAi21Jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi chatApi,
|
||||
BedrockAi21Jurassic2ChatOptions options) {
|
||||
public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options) {
|
||||
Assert.notNull(chatApi, "Ai21Jurassic2ChatBedrockApi must not be null");
|
||||
Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null");
|
||||
|
||||
@@ -49,7 +48,7 @@ public class BedrockAi21Jurassic2ChatClient implements ChatClient {
|
||||
this.defaultOptions = options;
|
||||
}
|
||||
|
||||
public BedrockAi21Jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi chatApi) {
|
||||
public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) {
|
||||
this(chatApi,
|
||||
BedrockAi21Jurassic2ChatOptions.builder()
|
||||
.withTemperature(0.8f)
|
||||
@@ -114,11 +113,16 @@ public class BedrockAi21Jurassic2ChatClient implements ChatClient {
|
||||
return this;
|
||||
}
|
||||
|
||||
public BedrockAi21Jurassic2ChatClient build() {
|
||||
return new BedrockAi21Jurassic2ChatClient(chatApi,
|
||||
public BedrockAi21Jurassic2ChatModel build() {
|
||||
return new BedrockAi21Jurassic2ChatModel(chatApi,
|
||||
options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -413,4 +413,19 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions {
|
||||
}
|
||||
}
|
||||
|
||||
public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) {
|
||||
return builder().withPrompt(fromOptions.getPrompt())
|
||||
.withNumResults(fromOptions.getNumResults())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withMinTokens(fromOptions.getMinTokens())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.withPresencePenalty(fromOptions.getPresencePenalty())
|
||||
.withCountPenalty(fromOptions.getCountPenalty())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
|
||||
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest;
|
||||
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
|
||||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
@@ -371,7 +372,7 @@ public class Ai21Jurassic2ChatBedrockApi extends
|
||||
/**
|
||||
* Ai21 Jurassic2 models version.
|
||||
*/
|
||||
public enum Ai21Jurassic2ChatModel {
|
||||
public enum Ai21Jurassic2ChatModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* ai21.j2-mid-v1
|
||||
@@ -395,6 +396,11 @@ public class Ai21Jurassic2ChatBedrockApi extends
|
||||
Ai21Jurassic2ChatModel(String value) {
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -23,11 +23,11 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
|
||||
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
|
||||
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -35,25 +35,25 @@ import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
|
||||
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Llama chat
|
||||
* generative.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Wei Jiang
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
|
||||
public class BedrockLlamaChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final LlamaChatBedrockApi chatApi;
|
||||
|
||||
private final BedrockLlamaChatOptions defaultOptions;
|
||||
|
||||
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
|
||||
public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi) {
|
||||
this(chatApi,
|
||||
BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
|
||||
}
|
||||
|
||||
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
|
||||
public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
|
||||
Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
|
||||
Assert.notNull(options, "BedrockLlamaChatOptions must not be null");
|
||||
|
||||
@@ -130,4 +130,9 @@ public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
|
||||
return request;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return BedrockLlamaChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -109,4 +109,11 @@ public class BedrockLlamaChatOptions implements ChatOptions {
|
||||
throw new UnsupportedOperationException("Unsupported option: 'TopK'");
|
||||
}
|
||||
|
||||
public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) {
|
||||
return builder().withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withMaxGenLen(fromOptions.getMaxGenLen())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import software.amazon.awssdk.regions.Region;
|
||||
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
|
||||
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
|
||||
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@@ -204,7 +205,7 @@ public class LlamaChatBedrockApi extends
|
||||
/**
|
||||
* Llama models version.
|
||||
*/
|
||||
public enum LlamaChatModel {
|
||||
public enum LlamaChatModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* meta.llama2-13b-chat-v1
|
||||
@@ -238,6 +239,11 @@ public class LlamaChatBedrockApi extends
|
||||
LlamaChatModel(String value) {
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -24,11 +24,11 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -39,17 +39,17 @@ import org.springframework.util.Assert;
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
|
||||
public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final TitanChatBedrockApi chatApi;
|
||||
|
||||
private final BedrockTitanChatOptions defaultOptions;
|
||||
|
||||
public BedrockTitanChatClient(TitanChatBedrockApi chatApi) {
|
||||
public BedrockTitanChatModel(TitanChatBedrockApi chatApi) {
|
||||
this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build());
|
||||
}
|
||||
|
||||
public BedrockTitanChatClient(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
|
||||
public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
|
||||
Assert.notNull(chatApi, "ChatApi must not be null");
|
||||
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
|
||||
this.chatApi = chatApi;
|
||||
@@ -146,4 +146,9 @@ public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return BedrockTitanChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -128,4 +128,12 @@ public class BedrockTitanChatOptions implements ChatOptions {
|
||||
throw new UnsupportedOperationException("Bedrock Titan Chat does not support the 'TopK' option.'");
|
||||
}
|
||||
|
||||
public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) {
|
||||
return builder().withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withMaxTokenCount(fromOptions.getMaxTokenCount())
|
||||
.withStopSequences(fromOptions.getStopSequences())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -34,7 +34,7 @@ import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* {@link org.springframework.ai.embedding.EmbeddingClient} implementation that uses the
|
||||
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
|
||||
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in
|
||||
* base64) inputs.
|
||||
*
|
||||
@@ -44,7 +44,7 @@ import org.springframework.util.Assert;
|
||||
* @author Wei Jiang
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -61,7 +61,7 @@ public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
|
||||
*/
|
||||
private InputType inputType = InputType.TEXT;
|
||||
|
||||
public BedrockTitanEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
|
||||
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
|
||||
this.embeddingApi = titanEmbeddingBedrockApi;
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
|
||||
* Titan Embedding API input types. Could be either text or image (encoded in base64).
|
||||
* @param inputType the input type to use.
|
||||
*/
|
||||
public BedrockTitanEmbeddingClient withInputType(InputType inputType) {
|
||||
public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
|
||||
this.inputType = inputType;
|
||||
return this;
|
||||
}
|
||||
@@ -18,7 +18,7 @@ package org.springframework.ai.bedrock.titan;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
|
||||
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
|
||||
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatReq
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
|
||||
/**
|
||||
* Java client for the Bedrock Titan chat model.
|
||||
@@ -265,7 +266,7 @@ public class TitanChatBedrockApi extends
|
||||
/**
|
||||
* Titan models version.
|
||||
*/
|
||||
public enum TitanChatModel {
|
||||
public enum TitanChatModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* amazon.titan-text-lite-v1
|
||||
@@ -294,6 +295,11 @@ public class TitanChatBedrockApi extends
|
||||
TitanChatModel(String value) {
|
||||
this.id = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -55,12 +55,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockAnthropicChatClientIT {
|
||||
class BedrockAnthropicChatModelIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropicChatClientIT.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropicChatModelIT.class);
|
||||
|
||||
@Autowired
|
||||
private BedrockAnthropicChatClient client;
|
||||
private BedrockAnthropicChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -68,8 +68,8 @@ class BedrockAnthropicChatClientIT {
|
||||
@Test
|
||||
void multipleStreamAttempts() {
|
||||
|
||||
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
|
||||
String joke1 = joke1Stream.collectList()
|
||||
.block()
|
||||
@@ -101,7 +101,7 @@ class BedrockAnthropicChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
@@ -119,7 +119,7 @@ class BedrockAnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputParser.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -137,7 +137,7 @@ class BedrockAnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -161,7 +161,7 @@ class BedrockAnthropicChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -182,7 +182,7 @@ class BedrockAnthropicChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -209,8 +209,8 @@ class BedrockAnthropicChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockAnthropicChatClient anthropicChatClient(AnthropicChatBedrockApi anthropicApi) {
|
||||
return new BedrockAnthropicChatClient(anthropicApi);
|
||||
public BedrockAnthropicChatModel anthropicChatModel(AnthropicChatBedrockApi anthropicApi) {
|
||||
return new BedrockAnthropicChatModel(anthropicApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -38,7 +38,7 @@ public class BedrockAnthropicCreateRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new BedrockAnthropicChatClient(anthropicChatApi,
|
||||
var client = new BedrockAnthropicChatModel(anthropicChatApi,
|
||||
AnthropicChatOptions.builder()
|
||||
.withTemperature(66.6f)
|
||||
.withTopK(66)
|
||||
|
||||
@@ -59,12 +59,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockAnthropic3ChatClientIT {
|
||||
class BedrockAnthropic3ChatModelIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropic3ChatClientIT.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropic3ChatModelIT.class);
|
||||
|
||||
@Autowired
|
||||
private BedrockAnthropic3ChatClient client;
|
||||
private BedrockAnthropic3ChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -72,8 +72,8 @@ class BedrockAnthropic3ChatClientIT {
|
||||
@Test
|
||||
void multipleStreamAttempts() {
|
||||
|
||||
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
|
||||
String joke1 = joke1Stream.collectList()
|
||||
.block()
|
||||
@@ -105,7 +105,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
@@ -123,7 +123,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -142,7 +142,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -166,7 +166,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -187,7 +187,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -211,7 +211,7 @@ class BedrockAnthropic3ChatClientIT {
|
||||
var userMessage = new UserMessage("Explain what do you see o this picture?",
|
||||
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
||||
|
||||
var response = client.call(new Prompt(List.of(userMessage)));
|
||||
var response = chatModel.call(new Prompt(List.of(userMessage)));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
|
||||
@@ -228,8 +228,8 @@ class BedrockAnthropic3ChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockAnthropic3ChatClient anthropicChatClient(Anthropic3ChatBedrockApi anthropicApi) {
|
||||
return new BedrockAnthropic3ChatClient(anthropicApi);
|
||||
public BedrockAnthropic3ChatModel anthropicChatModel(Anthropic3ChatBedrockApi anthropicApi) {
|
||||
return new BedrockAnthropic3ChatModel(anthropicApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -37,7 +37,7 @@ public class BedrockAnthropic3CreateRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new BedrockAnthropic3ChatClient(anthropicChatApi,
|
||||
var client = new BedrockAnthropic3ChatModel(anthropicChatApi,
|
||||
Anthropic3ChatOptions.builder()
|
||||
.withTemperature(66.6f)
|
||||
.withTopK(66)
|
||||
|
||||
@@ -45,7 +45,7 @@ public class BedrockCohereChatCreateRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new BedrockCohereChatClient(chatApi,
|
||||
var client = new BedrockCohereChatModel(chatApi,
|
||||
BedrockCohereChatOptions.builder()
|
||||
.withTemperature(66.6f)
|
||||
.withTopK(66)
|
||||
|
||||
@@ -54,10 +54,10 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockCohereChatClientIT {
|
||||
class BedrockCohereChatModelIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockCohereChatClient client;
|
||||
private BedrockCohereChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -65,8 +65,8 @@ class BedrockCohereChatClientIT {
|
||||
@Test
|
||||
void multipleStreamAttempts() {
|
||||
|
||||
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
|
||||
String joke1 = joke1Stream.collectList()
|
||||
.block()
|
||||
@@ -98,7 +98,7 @@ class BedrockCohereChatClientIT {
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
|
||||
@@ -115,7 +115,7 @@ class BedrockCohereChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -134,7 +134,7 @@ class BedrockCohereChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -157,7 +157,7 @@ class BedrockCohereChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -178,7 +178,7 @@ class BedrockCohereChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -205,8 +205,8 @@ class BedrockCohereChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockCohereChatClient cohereChatClient(CohereChatBedrockApi cohereApi) {
|
||||
return new BedrockCohereChatClient(cohereApi);
|
||||
public BedrockCohereChatModel cohereChatModel(CohereChatBedrockApi cohereApi) {
|
||||
return new BedrockCohereChatModel(cohereApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -39,24 +39,24 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockCohereEmbeddingClientIT {
|
||||
class BedrockCohereEmbeddingModelIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockCohereEmbeddingClient embeddingClient;
|
||||
private BedrockCohereEmbeddingModel embeddingModel;
|
||||
|
||||
@Test
|
||||
void singleEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@Test
|
||||
void batchEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel
|
||||
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(2);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
@@ -64,13 +64,13 @@ class BedrockCohereEmbeddingClientIT {
|
||||
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
|
||||
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@Test
|
||||
void embeddingWthOptions() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel
|
||||
.call(new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
|
||||
BedrockCohereEmbeddingOptions.builder().withInputType(InputType.SEARCH_DOCUMENT).build()));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(2);
|
||||
@@ -79,7 +79,7 @@ class BedrockCohereEmbeddingClientIT {
|
||||
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
|
||||
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
@@ -93,8 +93,8 @@ class BedrockCohereEmbeddingClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockCohereEmbeddingClient cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
|
||||
return new BedrockCohereEmbeddingClient(cohereEmbeddingApi);
|
||||
public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
|
||||
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -49,10 +49,10 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockAi21Jurassic2ChatClientIT {
|
||||
class BedrockAi21Jurassic2ChatModelIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockAi21Jurassic2ChatClient client;
|
||||
private BedrockAi21Jurassic2ChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -66,7 +66,7 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
@@ -83,7 +83,7 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄 ?");
|
||||
Prompt prompt = new Prompt(List.of(userMessage), options);
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).matches(content -> content.contains("😄"));
|
||||
}
|
||||
@@ -103,7 +103,7 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), options);
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).doesNotContain("😄");
|
||||
}
|
||||
@@ -120,7 +120,7 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -135,7 +135,7 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("AI");
|
||||
}
|
||||
@@ -152,9 +152,9 @@ class BedrockAi21Jurassic2ChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockAi21Jurassic2ChatClient bedrockAi21Jurassic2ChatClient(
|
||||
public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(
|
||||
Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi) {
|
||||
return new BedrockAi21Jurassic2ChatClient(jurassic2ChatBedrockApi,
|
||||
return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi,
|
||||
BedrockAi21Jurassic2ChatOptions.builder()
|
||||
.withTemperature(0.5f)
|
||||
.withMaxTokens(100)
|
||||
@@ -45,7 +45,7 @@ public class BedrockLlamaCreateRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new BedrockLlamaChatClient(api,
|
||||
var client = new BedrockLlamaChatModel(api,
|
||||
BedrockLlamaChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content"));
|
||||
|
||||
@@ -54,10 +54,10 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockLlamaChatClientIT {
|
||||
class BedrockLlamaChatModelIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockLlamaChatClient client;
|
||||
private BedrockLlamaChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -65,8 +65,8 @@ class BedrockLlamaChatClientIT {
|
||||
@Test
|
||||
void multipleStreamAttempts() {
|
||||
|
||||
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a Toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
|
||||
String joke1 = joke1Stream.collectList()
|
||||
.block()
|
||||
@@ -98,7 +98,7 @@ class BedrockLlamaChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
@@ -116,7 +116,7 @@ class BedrockLlamaChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -134,7 +134,7 @@ class BedrockLlamaChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -158,7 +158,7 @@ class BedrockLlamaChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -179,7 +179,7 @@ class BedrockLlamaChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -206,8 +206,8 @@ class BedrockLlamaChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi) {
|
||||
return new BedrockLlamaChatClient(llamaApi,
|
||||
public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi) {
|
||||
return new BedrockLlamaChatModel(llamaApi,
|
||||
BedrockLlamaChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build());
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ public class BedrockTitanChatCreateRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new BedrockTitanChatClient(api,
|
||||
var client = new BedrockTitanChatModel(api,
|
||||
BedrockTitanChatOptions.builder()
|
||||
.withTemperature(66.6f)
|
||||
.withTopP(0.66f)
|
||||
|
||||
@@ -26,7 +26,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
|
||||
import software.amazon.awssdk.regions.Region;
|
||||
|
||||
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
|
||||
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
|
||||
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -44,19 +44,19 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockTitanEmbeddingClientIT {
|
||||
class BedrockTitanEmbeddingModelIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockTitanEmbeddingClient embeddingClient;
|
||||
private BedrockTitanEmbeddingModel embeddingModel;
|
||||
|
||||
@Test
|
||||
void singleEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"),
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"),
|
||||
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -65,12 +65,12 @@ class BedrockTitanEmbeddingClientIT {
|
||||
byte[] image = new DefaultResourceLoader().getResource("classpath:/spring_framework.png")
|
||||
.getContentAsByteArray();
|
||||
|
||||
EmbeddingResponse embeddingResponse = embeddingClient
|
||||
EmbeddingResponse embeddingResponse = embeddingModel
|
||||
.call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
|
||||
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
@@ -84,8 +84,8 @@ class BedrockTitanEmbeddingClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockTitanEmbeddingClient titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
|
||||
return new BedrockTitanEmbeddingClient(titanEmbeddingApi);
|
||||
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
|
||||
return new BedrockTitanEmbeddingModel(titanEmbeddingApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -55,10 +55,10 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
|
||||
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
|
||||
class BedrockTitanChatClientIT {
|
||||
class BedrockTitanModelCalerlIT {
|
||||
|
||||
@Autowired
|
||||
private BedrockTitanChatClient client;
|
||||
private BedrockTitanChatModel chatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -66,8 +66,8 @@ class BedrockTitanChatClientIT {
|
||||
@Test
|
||||
void multipleStreamAttempts() {
|
||||
|
||||
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
|
||||
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
|
||||
|
||||
String joke1 = joke1Stream.collectList()
|
||||
.block()
|
||||
@@ -99,7 +99,7 @@ class BedrockTitanChatClientIT {
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
|
||||
@@ -117,7 +117,7 @@ class BedrockTitanChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -138,7 +138,7 @@ class BedrockTitanChatClientIT {
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -162,7 +162,7 @@ class BedrockTitanChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -184,7 +184,7 @@ class BedrockTitanChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -211,8 +211,8 @@ class BedrockTitanChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanApi) {
|
||||
return new BedrockTitanChatClient(titanApi);
|
||||
public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanApi) {
|
||||
return new BedrockTitanChatModel(titanApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import java.util.Map;
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
|
||||
@@ -31,15 +31,17 @@ import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
|
||||
import org.springframework.ai.huggingface.model.GenerateParameters;
|
||||
import org.springframework.ai.huggingface.model.GenerateRequest;
|
||||
import org.springframework.ai.huggingface.model.GenerateResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
/**
|
||||
* An implementation of {@link ChatClient} that interfaces with HuggingFace Inference
|
||||
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
|
||||
* Endpoints for text generation.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
*/
|
||||
public class HuggingfaceChatClient implements ChatClient {
|
||||
public class HuggingfaceChatModel implements ChatModel {
|
||||
|
||||
/**
|
||||
* Token required for authenticating with the HuggingFace Inference API.
|
||||
@@ -68,11 +70,11 @@ public class HuggingfaceChatClient implements ChatClient {
|
||||
private int maxNewTokens = 1000;
|
||||
|
||||
/**
|
||||
* Constructs a new HuggingfaceChatClient with the specified API token and base path.
|
||||
* Constructs a new HuggingfaceChatModel with the specified API token and base path.
|
||||
* @param apiToken The API token for HuggingFace.
|
||||
* @param basePath The base path for API requests.
|
||||
*/
|
||||
public HuggingfaceChatClient(final String apiToken, String basePath) {
|
||||
public HuggingfaceChatModel(final String apiToken, String basePath) {
|
||||
this.apiToken = apiToken;
|
||||
this.apiClient.setBasePath(basePath);
|
||||
this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken);
|
||||
@@ -120,4 +122,9 @@ public class HuggingfaceChatClient implements ChatClient {
|
||||
this.maxNewTokens = maxNewTokens;
|
||||
}
|
||||
|
||||
/**
 * Returns the default chat options for this model.
 * <p>
 * The result is whatever {@code ChatOptionsBuilder.builder().build()} yields with
 * nothing set on the builder; none of this instance's own state is consulted.
 * @return the options object built from an unconfigured {@code ChatOptionsBuilder}
 */
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return ChatOptionsBuilder.builder().build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import org.springframework.util.StringUtils;
|
||||
public class HuggingfaceTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public HuggingfaceChatClient huggingfaceChatClient() {
|
||||
public HuggingfaceChatModel huggingfaceChatModel() {
|
||||
String apiKey = System.getenv("HUGGINGFACE_API_KEY");
|
||||
if (!StringUtils.hasText(apiKey)) {
|
||||
throw new IllegalArgumentException(
|
||||
@@ -31,9 +31,9 @@ public class HuggingfaceTestConfiguration {
|
||||
}
|
||||
// Created aws-mistral-7b-instruct-v0-1-805 via
|
||||
// https://ui.endpoints.huggingface.co/
|
||||
HuggingfaceChatClient huggingfaceChatClient = new HuggingfaceChatClient(apiKey,
|
||||
HuggingfaceChatModel huggingfaceChatModel = new HuggingfaceChatModel(apiKey,
|
||||
"https://f6hg7b3cvlmntp5i.us-east-1.aws.endpoints.huggingface.cloud");
|
||||
return huggingfaceChatClient;
|
||||
return huggingfaceChatModel;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.huggingface.HuggingfaceChatClient;
|
||||
import org.springframework.ai.huggingface.HuggingfaceChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -33,7 +33,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
public class ClientIT {
|
||||
|
||||
@Autowired
|
||||
protected HuggingfaceChatClient huggingfaceChatClient;
|
||||
protected HuggingfaceChatModel huggingfaceChatModel;
|
||||
|
||||
@Test
|
||||
void helloWorldCompletion() {
|
||||
@@ -46,7 +46,7 @@ public class ClientIT {
|
||||
[/INST]
|
||||
""";
|
||||
Prompt prompt = new Prompt(mistral7bInstruct);
|
||||
ChatResponse chatResponse = huggingfaceChatClient.call(prompt);
|
||||
ChatResponse chatResponse = huggingfaceChatModel.call(prompt);
|
||||
assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
|
||||
String expectedResponse = """
|
||||
```json
|
||||
|
||||
@@ -15,30 +15,6 @@
|
||||
*/
|
||||
package org.springframework.ai.minimax;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.*;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Base64;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -47,21 +23,51 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionFinishReason;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* {@link ChatClient} and {@link StreamingChatClient} implementation for
|
||||
* {@literal MiniMax} backed by {@link MiniMaxApi}.
|
||||
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax}
|
||||
* backed by {@link MiniMaxApi}.
|
||||
*
|
||||
* @author Geng Rong
|
||||
* @see ChatClient
|
||||
* @see StreamingChatClient
|
||||
* @see ChatModel
|
||||
* @see StreamingChatModel
|
||||
* @see MiniMaxApi
|
||||
* @since 1.0.0 M1
|
||||
*/
|
||||
public class MiniMaxChatClient extends
|
||||
AbstractFunctionCallSupport<ChatCompletionMessage, ChatCompletionRequest, ResponseEntity<ChatCompletion>>
|
||||
implements ChatClient, StreamingChatClient {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatClient.class);
|
||||
public class MiniMaxChatModel extends
|
||||
AbstractFunctionCallSupport<MiniMaxApi.ChatCompletionMessage, MiniMaxApi.ChatCompletionRequest, ResponseEntity<MiniMaxApi.ChatCompletion>>
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatModel.class);
|
||||
|
||||
/**
|
||||
* The default options used for the chat completion requests.
|
||||
@@ -79,35 +85,35 @@ public class MiniMaxChatClient extends
|
||||
private final MiniMaxApi miniMaxApi;
|
||||
|
||||
/**
|
||||
* Creates an instance of the MiniMaxChatClient.
|
||||
* Creates an instance of the MiniMaxChatModel.
|
||||
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
|
||||
* MiniMax Chat API.
|
||||
* MiniMax Chat API.
|
||||
* @throws IllegalArgumentException if MiniMaxApi is null
|
||||
*/
|
||||
public MiniMaxChatClient(MiniMaxApi miniMaxApi) {
|
||||
public MiniMaxChatModel(MiniMaxApi miniMaxApi) {
|
||||
this(miniMaxApi,
|
||||
MiniMaxChatOptions.builder().withModel(MiniMaxApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes an instance of the MiniMaxChatClient.
|
||||
* Initializes an instance of the MiniMaxChatModel.
|
||||
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
|
||||
* MiniMax Chat API.
|
||||
* @param options The MiniMaxChatOptions to configure the chat client.
|
||||
* @param options The MiniMaxChatOptions to configure the chat model.
|
||||
*/
|
||||
public MiniMaxChatClient(MiniMaxApi miniMaxApi, MiniMaxChatOptions options) {
|
||||
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options) {
|
||||
this(miniMaxApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MiniMaxChatClient.
|
||||
* Initializes a new instance of the MiniMaxChatModel.
|
||||
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
|
||||
* MiniMax Chat API.
|
||||
* @param options The MiniMaxChatOptions to configure the chat client.
|
||||
* @param options The MiniMaxChatOptions to configure the chat model.
|
||||
* @param functionCallbackContext The function callback context.
|
||||
* @param retryTemplate The retry template.
|
||||
*/
|
||||
public MiniMaxChatClient(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
|
||||
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
|
||||
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
|
||||
super(functionCallbackContext);
|
||||
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
|
||||
@@ -279,7 +285,7 @@ public class MiniMaxChatClient extends
|
||||
return request;
|
||||
}
|
||||
|
||||
private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
|
||||
private List<MiniMaxApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
|
||||
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
|
||||
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
|
||||
functionCallback.getInputTypeSchema());
|
||||
@@ -358,4 +364,9 @@ public class MiniMaxChatClient extends
|
||||
&& choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
|
||||
}
|
||||
|
||||
/**
 * Returns the default options of this chat model.
 * <p>
 * The {@code defaultOptions} field is not handed out directly; it is passed through
 * {@code MiniMaxChatOptions.fromOptions(...)} and that method's result is returned.
 * NOTE(review): presumably this is meant to keep callers from mutating the model's
 * own defaults - confirm how deep the copy made by {@code fromOptions} is.
 * @return the options produced by {@code MiniMaxChatOptions.fromOptions} from this
 * model's {@code defaultOptions}
 */
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return MiniMaxChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -114,10 +114,10 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
|
||||
private @JsonProperty("tool_choice") String toolChoice;
|
||||
|
||||
/**
|
||||
* MiniMax Tool Function Callbacks to register with the ChatClient.
|
||||
* MiniMax Tool Function Callbacks to register with the ChatModel.
|
||||
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
|
||||
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
|
||||
* from the registry to be used by the ChatClient chat completion requests.
|
||||
* from the registry to be used by the ChatModel chat completion requests.
|
||||
*/
|
||||
@NestedConfigurationProperty
|
||||
@JsonIgnore
|
||||
@@ -467,4 +467,22 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * Copy factory: builds a new {@code MiniMaxChatOptions} from an existing one.
 * <p>
 * Each of the following properties is read from {@code fromOptions} through its
 * getter and fed to the matching builder method: model, frequencyPenalty, maxTokens,
 * n, presencePenalty, responseFormat, seed, stop, temperature, topP, tools,
 * toolChoice, functionCallbacks and functions.
 * <p>
 * The argument is dereferenced without a null check, so passing {@code null} fails
 * with a {@link NullPointerException} on the first getter call.
 * <p>
 * NOTE(review): the getter results for stop, tools, functionCallbacks and functions
 * are passed to the builder as-is; unless the builder copies them, the new instance
 * shares those collections with the source - confirm against the builder.
 * @param fromOptions the options to copy from; must not be {@code null}
 * @return the options instance produced by the builder
 */
public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) {
|
||||
return builder().withModel(fromOptions.getModel())
|
||||
.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withN(fromOptions.getN())
|
||||
.withPresencePenalty(fromOptions.getPresencePenalty())
|
||||
.withResponseFormat(fromOptions.getResponseFormat())
|
||||
.withSeed(fromOptions.getSeed())
|
||||
.withStop(fromOptions.getStop())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTools(fromOptions.getTools())
|
||||
.withToolChoice(fromOptions.getToolChoice())
|
||||
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
|
||||
.withFunctions(fromOptions.getFunctions())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -40,9 +40,9 @@ import java.util.List;
|
||||
* @author Geng Rong
|
||||
* @since 1.0.0 M1
|
||||
*/
|
||||
public class MiniMaxEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MiniMaxEmbeddingClient.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(MiniMaxEmbeddingModel.class);
|
||||
|
||||
private final MiniMaxEmbeddingOptions defaultOptions;
|
||||
|
||||
@@ -53,43 +53,43 @@ public class MiniMaxEmbeddingClient extends AbstractEmbeddingClient {
|
||||
private final MetadataMode metadataMode;
|
||||
|
||||
/**
|
||||
* Constructor for the MiniMaxEmbeddingClient class.
|
||||
* Constructor for the MiniMaxEmbeddingModel class.
|
||||
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
|
||||
*/
|
||||
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi) {
|
||||
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi) {
|
||||
this(miniMaxApi, MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MiniMaxEmbeddingClient class.
|
||||
* Initializes a new instance of the MiniMaxEmbeddingModel class.
|
||||
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
|
||||
* @param metadataMode The mode for generating metadata.
|
||||
*/
|
||||
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode) {
|
||||
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode) {
|
||||
this(miniMaxApi, metadataMode,
|
||||
MiniMaxEmbeddingOptions.builder().withModel(MiniMaxApi.DEFAULT_EMBEDDING_MODEL).build(),
|
||||
RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MiniMaxEmbeddingClient class.
|
||||
* Initializes a new instance of the MiniMaxEmbeddingModel class.
|
||||
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
|
||||
* @param metadataMode The mode for generating metadata.
|
||||
* @param miniMaxEmbeddingOptions The options for MiniMax embedding.
|
||||
*/
|
||||
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
|
||||
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
|
||||
MiniMaxEmbeddingOptions miniMaxEmbeddingOptions) {
|
||||
this(miniMaxApi, metadataMode, miniMaxEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the MiniMaxEmbeddingClient class.
|
||||
* Initializes a new instance of the MiniMaxEmbeddingModel class.
|
||||
* @param miniMaxApi - The MiniMaxApi instance to use for making API requests.
|
||||
* @param metadataMode - The mode for generating metadata.
|
||||
* @param options - The options for MiniMax embedding.
|
||||
* @param retryTemplate - The RetryTemplate for retrying failed API requests.
|
||||
*/
|
||||
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
|
||||
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
|
||||
RetryTemplate retryTemplate) {
|
||||
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
|
||||
Assert.notNull(metadataMode, "metadataMode must not be null");
|
||||
@@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonInclude.Include;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.boot.context.properties.bind.ConstructorBinding;
|
||||
@@ -111,7 +113,7 @@ public class MiniMaxApi {
|
||||
* MiniMax Chat Completion Models:
|
||||
* <a href="https://www.minimaxi.com/document/algorithm-concept">MiniMax Model</a>.
|
||||
*/
|
||||
public enum ChatModel {
|
||||
public enum ChatModel implements ModelDescription {
|
||||
ABAB_6_Chat("abab6-chat"),
|
||||
ABAB_5_5_Chat("abab5.5-chat"),
|
||||
ABAB_5_5_S_Chat("abab5.5s-chat");
|
||||
@@ -125,6 +127,11 @@ public class MiniMaxApi {
|
||||
/**
 * Returns the {@code value} field of this constant.
 * NOTE(review): presumably the model id string given in the constant declaration
 * (e.g. "abab6-chat") - the constructor is not visible here, confirm.
 * @return this constant's {@code value}
 */
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
/**
 * Returns the {@code value} field of this constant - the same field that
 * {@code getValue()} returns.
 * NOTE(review): the {@code @Override} presumably targets the
 * {@code ModelDescription} interface this enum implements - confirm.
 * @return this constant's {@code value}
 */
@Override
|
||||
public String getModelName() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -33,7 +33,7 @@ public class ChatCompletionRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
|
||||
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
|
||||
MiniMaxChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content"), false);
|
||||
@@ -59,7 +59,7 @@ public class ChatCompletionRequestTests {
|
||||
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
|
||||
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
|
||||
MiniMaxChatOptions.builder().withModel("DEFAULT_MODEL").build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content",
|
||||
@@ -89,7 +89,7 @@ public class ChatCompletionRequestTests {
|
||||
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
|
||||
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
|
||||
MiniMaxChatOptions.builder()
|
||||
.withModel("DEFAULT_MODEL")
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
package org.springframework.ai.minimax;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingClient;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -42,13 +42,13 @@ public class MiniMaxTestConfiguration {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MiniMaxChatClient miniMaxChatClient(MiniMaxApi api) {
|
||||
return new MiniMaxChatClient(api);
|
||||
public MiniMaxChatModel miniMaxChatModel(MiniMaxApi api) {
|
||||
return new MiniMaxChatModel(api);
|
||||
}
|
||||
|
||||
@Bean
|
||||
public EmbeddingClient miniMaxEmbeddingClient(MiniMaxApi api) {
|
||||
return new MiniMaxEmbeddingClient(api);
|
||||
public EmbeddingModel miniMaxEmbeddingModel(MiniMaxApi api) {
|
||||
return new MiniMaxEmbeddingModel(api);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,9 +22,9 @@ import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.minimax.MiniMaxChatClient;
|
||||
import org.springframework.ai.minimax.MiniMaxChatModel;
|
||||
import org.springframework.ai.minimax.MiniMaxChatOptions;
|
||||
import org.springframework.ai.minimax.MiniMaxEmbeddingClient;
|
||||
import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
|
||||
import org.springframework.ai.minimax.MiniMaxEmbeddingOptions;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion;
|
||||
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk;
|
||||
@@ -83,9 +83,9 @@ public class MiniMaxRetryTests {
|
||||
|
||||
private @Mock MiniMaxApi miniMaxApi;
|
||||
|
||||
private MiniMaxChatClient chatClient;
|
||||
private MiniMaxChatModel chatModel;
|
||||
|
||||
private MiniMaxEmbeddingClient embeddingClient;
|
||||
private MiniMaxEmbeddingModel embeddingModel;
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
@@ -93,8 +93,8 @@ public class MiniMaxRetryTests {
|
||||
retryListener = new TestRetryListener();
|
||||
retryTemplate.registerListener(retryListener);
|
||||
|
||||
chatClient = new MiniMaxChatClient(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate);
|
||||
embeddingClient = new MiniMaxEmbeddingClient(miniMaxApi, MetadataMode.EMBED,
|
||||
chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate);
|
||||
embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED,
|
||||
MiniMaxEmbeddingOptions.builder().build(), retryTemplate);
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ public class MiniMaxRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
|
||||
|
||||
var result = chatClient.call(new Prompt("text"));
|
||||
var result = chatModel.call(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
|
||||
@@ -123,7 +123,7 @@ public class MiniMaxRetryTests {
|
||||
public void miniMaxChatNonTransientError() {
|
||||
when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
|
||||
assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -139,7 +139,7 @@ public class MiniMaxRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(Flux.just(expectedChatCompletion));
|
||||
|
||||
var result = chatClient.stream(new Prompt("text"));
|
||||
var result = chatModel.stream(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
|
||||
@@ -151,7 +151,7 @@ public class MiniMaxRetryTests {
|
||||
public void miniMaxChatStreamNonTransientError() {
|
||||
when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
|
||||
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -164,7 +164,7 @@ public class MiniMaxRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
|
||||
|
||||
var result = embeddingClient
|
||||
var result = embeddingModel
|
||||
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
@@ -177,7 +177,7 @@ public class MiniMaxRetryTests {
|
||||
public void miniMaxEmbeddingNonTransientError() {
|
||||
when(miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> embeddingClient
|
||||
assertThrows(RuntimeException.class, () -> embeddingModel
|
||||
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
|
||||
}
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ package org.springframework.ai.mistralai;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -55,9 +55,9 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
* @author Grogdunn
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public class MistralAiChatClient extends
|
||||
public class MistralAiChatModel extends
|
||||
AbstractFunctionCallSupport<MistralAiApi.ChatCompletionMessage, MistralAiApi.ChatCompletionRequest, ResponseEntity<MistralAiApi.ChatCompletion>>
|
||||
implements ChatClient, StreamingChatClient {
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private final Logger log = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -73,7 +73,7 @@ public class MistralAiChatClient extends
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
public MistralAiChatClient(MistralAiApi mistralAiApi) {
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi) {
|
||||
this(mistralAiApi,
|
||||
MistralAiChatOptions.builder()
|
||||
.withTemperature(0.7f)
|
||||
@@ -83,11 +83,11 @@ public class MistralAiChatClient extends
|
||||
.build());
|
||||
}
|
||||
|
||||
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
|
||||
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
|
||||
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
|
||||
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
|
||||
super(functionCallbackContext);
|
||||
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
|
||||
@@ -324,4 +324,9 @@ public class MistralAiChatClient extends
|
||||
return !CollectionUtils.isEmpty(choices.get(0).message().toolCalls());
|
||||
}
|
||||
|
||||
/**
 * Returns the default chat options of this model as a separate
 * {@link MistralAiChatOptions} instance.
 *
 * The result is built by {@link MistralAiChatOptions#fromOptions}, so callers
 * receive a new options object rather than the {@code defaultOptions} field itself.
 * @return a new options instance populated from {@code this.defaultOptions}
 */
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return MistralAiChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -101,11 +101,11 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
|
||||
private @JsonProperty("tool_choice") ToolChoice toolChoice;
|
||||
|
||||
/**
|
||||
* MistralAI Tool Function Callbacks to register with the ChatClient. For Prompt
|
||||
* MistralAI Tool Function Callbacks to register with the ChatModel. For Prompt
|
||||
* Options the functionCallbacks are automatically enabled for the duration of the
|
||||
* prompt execution. For Default Options the functionCallbacks are registered but
|
||||
* disabled by default. Use the enableFunctions to set the functions from the registry
|
||||
* to be used by the ChatClient chat completion requests.
|
||||
* to be used by the ChatModel chat completion requests.
|
||||
*/
|
||||
@NestedConfigurationProperty
|
||||
@JsonIgnore
|
||||
@@ -139,7 +139,7 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withMaxToken(Integer maxTokens) {
|
||||
public Builder withMaxTokens(Integer maxTokens) {
|
||||
this.options.setMaxTokens(maxTokens);
|
||||
return this;
|
||||
}
|
||||
@@ -309,4 +309,19 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
|
||||
this.functions = functions;
|
||||
}
|
||||
|
||||
/**
 * Copy factory: builds a new {@link MistralAiChatOptions} whose properties are read
 * from the given instance.
 *
 * The properties transferred are model, maxTokens, safePrompt, randomSeed,
 * temperature, topP, responseFormat, tools, toolChoice, functionCallbacks and
 * functions.
 *
 * NOTE(review): the getter results are handed to the builder as-is; whether the
 * builder copies the tools / functionCallbacks / functions values is not visible
 * here, so the new instance may share those objects with the source - confirm
 * before mutating them.
 * @param fromOptions the options to read from; getters are invoked on it directly,
 * so a {@code null} argument fails with a {@link NullPointerException}
 * @return a newly built options instance
 */
public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) {
|
||||
return builder().withModel(fromOptions.getModel())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withSafePrompt(fromOptions.getSafePrompt())
|
||||
.withRandomSeed(fromOptions.getRandomSeed())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withResponseFormat(fromOptions.getResponseFormat())
|
||||
.withTools(fromOptions.getTools())
|
||||
.withToolChoice(fromOptions.getToolChoice())
|
||||
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
|
||||
.withFunctions(fromOptions.getFunctions())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -38,7 +38,7 @@ import org.springframework.util.Assert;
|
||||
* @author Ricken Bazolo
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private final Logger log = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -50,21 +50,21 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
|
||||
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
|
||||
this(mistralAiApi, MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
|
||||
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
|
||||
this(mistralAiApi, metadataMode,
|
||||
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
|
||||
RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
|
||||
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
|
||||
this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode,
|
||||
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
|
||||
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) {
|
||||
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
|
||||
Assert.notNull(metadataMode, "metadataMode must not be null");
|
||||
@@ -27,6 +27,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.boot.context.properties.bind.ConstructorBinding;
|
||||
@@ -706,7 +707,7 @@ public class MistralAiApi {
|
||||
* <li><b>LARGE</b> - mistral-large-latest (aka mistral-large-2402)</li>
|
||||
* </ul>
|
||||
*/
|
||||
public enum ChatModel {
|
||||
public enum ChatModel implements ModelDescription {
|
||||
|
||||
// @formatter:off
|
||||
TINY("open-mistral-7b"),
|
||||
@@ -726,6 +727,11 @@ public class MistralAiApi {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
/**
 * {@link ModelDescription} accessor for this enum constant.
 * @return the {@code value} field of this constant
 */
@Override
|
||||
public String getModelName() {
|
||||
return this.value;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -32,12 +32,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
|
||||
public class MistralAiChatCompletionRequestTest {
|
||||
|
||||
MistralAiChatClient chatClient = new MistralAiChatClient(new MistralAiApi("test"));
|
||||
MistralAiChatModel chatModel = new MistralAiChatModel(new MistralAiApi("test"));
|
||||
|
||||
@Test
|
||||
void chatCompletionDefaultRequestTest() {
|
||||
|
||||
var request = chatClient.createRequest(new Prompt("test content"), false);
|
||||
var request = chatModel.createRequest(new Prompt("test content"), false);
|
||||
|
||||
assertThat(request.messages()).hasSize(1);
|
||||
assertThat(request.topP()).isEqualTo(1);
|
||||
@@ -52,7 +52,7 @@ public class MistralAiChatCompletionRequestTest {
|
||||
|
||||
var options = MistralAiChatOptions.builder().withTemperature(0.5f).withTopP(0.8f).build();
|
||||
|
||||
var request = chatClient.createRequest(new Prompt("test content", options), true);
|
||||
var request = chatModel.createRequest(new Prompt("test content", options), true);
|
||||
|
||||
assertThat(request.messages().size()).isEqualTo(1);
|
||||
assertThat(request.topP()).isEqualTo(0.8f);
|
||||
|
||||
@@ -27,10 +27,10 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
@@ -56,15 +56,15 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
*/
|
||||
@SpringBootTest(classes = MistralAiTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
|
||||
class MistralAiChatClientIT {
|
||||
class MistralAiChatModelIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MistralAiChatClientIT.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(MistralAiChatModelIT.class);
|
||||
|
||||
@Autowired
|
||||
protected ChatClient chatClient;
|
||||
protected ChatModel chatModel;
|
||||
|
||||
@Autowired
|
||||
protected StreamingChatClient streamingChatClient;
|
||||
protected StreamingChatModel streamingChatModel;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
@@ -90,7 +90,7 @@ class MistralAiChatClientIT {
|
||||
// NOTE: Mistral expects the system message to be before the user message or will
|
||||
// fail with 400 error.
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
ChatResponse response = chatClient.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
|
||||
}
|
||||
@@ -108,7 +108,7 @@ class MistralAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatClient.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -126,7 +126,7 @@ class MistralAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -148,7 +148,7 @@ class MistralAiChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
@@ -169,7 +169,7 @@ class MistralAiChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = streamingChatClient.stream(prompt)
|
||||
String generationTextFromStream = streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -201,7 +201,7 @@ class MistralAiChatClientIT {
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
|
||||
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
@@ -224,7 +224,7 @@ class MistralAiChatClientIT {
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(messages, promptOptions));
|
||||
Flux<ChatResponse> response = streamingChatModel.stream(new Prompt(messages, promptOptions));
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
@@ -31,25 +31,25 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
class MistralAiEmbeddingIT {
|
||||
|
||||
@Autowired
|
||||
private MistralAiEmbeddingClient mistralAiEmbeddingClient;
|
||||
private MistralAiEmbeddingModel mistralAiEmbeddingModel;
|
||||
|
||||
@Test
|
||||
void defaultEmbedding() {
|
||||
assertThat(mistralAiEmbeddingClient).isNotNull();
|
||||
var embeddingResponse = mistralAiEmbeddingClient.embedForResponse(List.of("Hello World"));
|
||||
assertThat(mistralAiEmbeddingModel).isNotNull();
|
||||
var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "mistral-embed");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 4);
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 4);
|
||||
assertThat(mistralAiEmbeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
@Test
|
||||
void embeddingTest() {
|
||||
assertThat(mistralAiEmbeddingClient).isNotNull();
|
||||
var embeddingResponse = mistralAiEmbeddingClient.call(new EmbeddingRequest(
|
||||
assertThat(mistralAiEmbeddingModel).isNotNull();
|
||||
var embeddingResponse = mistralAiEmbeddingModel.call(new EmbeddingRequest(
|
||||
List.of("Hello World", "World is big"),
|
||||
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(2);
|
||||
@@ -58,7 +58,7 @@ class MistralAiEmbeddingIT {
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "mistral-embed");
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 9);
|
||||
assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 9);
|
||||
assertThat(mistralAiEmbeddingClient.dimensions()).isEqualTo(1024);
|
||||
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -82,9 +82,9 @@ public class MistralAiRetryTests {
|
||||
|
||||
private @Mock MistralAiApi mistralAiApi;
|
||||
|
||||
private MistralAiChatClient chatClient;
|
||||
private MistralAiChatModel chatModel;
|
||||
|
||||
private MistralAiEmbeddingClient embeddingClient;
|
||||
private MistralAiEmbeddingModel embeddingModel;
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
@@ -92,7 +92,7 @@ public class MistralAiRetryTests {
|
||||
retryListener = new TestRetryListener();
|
||||
retryTemplate.registerListener(retryListener);
|
||||
|
||||
chatClient = new MistralAiChatClient(mistralAiApi,
|
||||
chatModel = new MistralAiChatModel(mistralAiApi,
|
||||
MistralAiChatOptions.builder()
|
||||
.withTemperature(0.7f)
|
||||
.withTopP(1f)
|
||||
@@ -100,7 +100,7 @@ public class MistralAiRetryTests {
|
||||
.withModel(MistralAiApi.ChatModel.TINY.getValue())
|
||||
.build(),
|
||||
null, retryTemplate);
|
||||
embeddingClient = new MistralAiEmbeddingClient(mistralAiApi, MetadataMode.EMBED,
|
||||
embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, MetadataMode.EMBED,
|
||||
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
|
||||
retryTemplate);
|
||||
}
|
||||
@@ -118,7 +118,7 @@ public class MistralAiRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
|
||||
|
||||
var result = chatClient.call(new Prompt("text"));
|
||||
var result = chatModel.call(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
|
||||
@@ -130,7 +130,7 @@ public class MistralAiRetryTests {
|
||||
public void mistralAiChatNonTransientError() {
|
||||
when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
|
||||
assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -146,7 +146,7 @@ public class MistralAiRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(Flux.just(expectedChatCompletion));
|
||||
|
||||
var result = chatClient.stream(new Prompt("text"));
|
||||
var result = chatModel.stream(new Prompt("text"));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
|
||||
@@ -158,7 +158,7 @@ public class MistralAiRetryTests {
|
||||
public void mistralAiChatStreamNonTransientError() {
|
||||
when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
|
||||
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -172,7 +172,7 @@ public class MistralAiRetryTests {
|
||||
.thenThrow(new TransientAiException("Transient Error 2"))
|
||||
.thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
|
||||
|
||||
var result = embeddingClient
|
||||
var result = embeddingModel
|
||||
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
@@ -185,7 +185,7 @@ public class MistralAiRetryTests {
|
||||
public void mistralAiEmbeddingNonTransientError() {
|
||||
when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
|
||||
.thenThrow(new RuntimeException("Non Transient Error"));
|
||||
assertThrows(RuntimeException.class, () -> embeddingClient
|
||||
assertThrows(RuntimeException.class, () -> embeddingModel
|
||||
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
package org.springframework.ai.mistralai;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingClient;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -35,14 +35,14 @@ public class MistralAiTestConfiguration {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public EmbeddingClient mistralAiEmbeddingClient(MistralAiApi api) {
|
||||
return new MistralAiEmbeddingClient(api,
|
||||
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
|
||||
return new MistralAiEmbeddingModel(api,
|
||||
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
|
||||
}
|
||||
|
||||
@Bean
|
||||
public MistralAiChatClient mistralAiChatClient(MistralAiApi mistralAiApi) {
|
||||
return new MistralAiChatClient(mistralAiApi,
|
||||
public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
|
||||
return new MistralAiChatModel(mistralAiApi,
|
||||
MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.MIXTRAL.getValue()).build());
|
||||
}
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ package org.springframework.ai.ollama;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
@@ -39,7 +39,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* {@link ChatClient} implementation for {@literal Ollama}.
|
||||
* {@link ChatModel} implementation for {@literal Ollama}.
|
||||
*
|
||||
* Ollama allows developers to run large language models and generate embeddings locally.
|
||||
* It supports open-source models available on [Ollama AI
|
||||
@@ -52,7 +52,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class OllamaChatClient implements ChatClient, StreamingChatClient {
|
||||
public class OllamaChatModel implements ChatModel, StreamingChatModel {
|
||||
|
||||
/**
|
||||
* Low-level Ollama API library.
|
||||
@@ -64,11 +64,11 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
|
||||
*/
|
||||
private OllamaOptions defaultOptions;
|
||||
|
||||
public OllamaChatClient(OllamaApi chatApi) {
|
||||
public OllamaChatModel(OllamaApi chatApi) {
|
||||
this(chatApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
|
||||
}
|
||||
|
||||
public OllamaChatClient(OllamaApi chatApi, OllamaOptions defaultOptions) {
|
||||
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions) {
|
||||
Assert.notNull(chatApi, "OllamaApi must not be null");
|
||||
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
|
||||
this.chatApi = chatApi;
|
||||
@@ -79,7 +79,7 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
|
||||
* @deprecated Use {@link OllamaOptions#setModel} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public OllamaChatClient withModel(String model) {
|
||||
public OllamaChatModel withModel(String model) {
|
||||
this.defaultOptions.setModel(model);
|
||||
return this;
|
||||
}
|
||||
@@ -88,7 +88,7 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
|
||||
* @deprecated Use {@link OllamaOptions} constructor instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public OllamaChatClient withDefaultOptions(OllamaOptions options) {
|
||||
public OllamaChatModel withDefaultOptions(OllamaOptions options) {
|
||||
this.defaultOptions = options;
|
||||
return this;
|
||||
}
|
||||
@@ -205,4 +205,9 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Returns the default options of this model as a separate {@link OllamaOptions}
 * instance.
 *
 * The result comes from {@link OllamaOptions#fromOptions}, which starts from a new
 * {@code OllamaOptions} object, so the {@code defaultOptions} field itself is not
 * handed out.
 * @return a new options instance populated from {@code this.defaultOptions}
 */
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return OllamaOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,9 +23,9 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingClient;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
@@ -36,7 +36,7 @@ import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* {@link EmbeddingClient} implementation for {@literal Ollama}.
|
||||
* {@link EmbeddingModel} implementation for {@literal Ollama}.
|
||||
*
|
||||
* Ollama allows developers to run large language models and generate embeddings locally.
|
||||
* It supports open-source models available on [Ollama AI
|
||||
@@ -51,7 +51,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -62,11 +62,11 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
|
||||
*/
|
||||
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
|
||||
|
||||
public OllamaEmbeddingClient(OllamaApi ollamaApi) {
|
||||
public OllamaEmbeddingModel(OllamaApi ollamaApi) {
|
||||
this.ollamaApi = ollamaApi;
|
||||
}
|
||||
|
||||
public OllamaEmbeddingClient(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
|
||||
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
|
||||
this.ollamaApi = ollamaApi;
|
||||
this.defaultOptions = defaultOptions;
|
||||
}
|
||||
@@ -75,7 +75,7 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
|
||||
* @deprecated Use {@link OllamaOptions#setModel} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public OllamaEmbeddingClient withModel(String model) {
|
||||
public OllamaEmbeddingModel withModel(String model) {
|
||||
this.defaultOptions.setModel(model);
|
||||
return this;
|
||||
}
|
||||
@@ -84,7 +84,7 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
|
||||
* @deprecated Use {@link OllamaOptions} constructor instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public OllamaEmbeddingClient withDefaultOptions(OllamaOptions options) {
|
||||
public OllamaEmbeddingModel withDefaultOptions(OllamaOptions options) {
|
||||
this.defaultOptions = options;
|
||||
return this;
|
||||
}
|
||||
@@ -15,13 +15,15 @@
|
||||
*/
|
||||
package org.springframework.ai.ollama.api;
|
||||
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
|
||||
/**
|
||||
* Helper class for common Ollama models.
|
||||
*
|
||||
* @author Siarhei Blashuk
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public enum OllamaModel {
|
||||
public enum OllamaModel implements ModelDescription {
|
||||
|
||||
/**
|
||||
* Llama 2 is a collection of language models ranging from 7B to 70B parameters.
|
||||
@@ -99,4 +101,9 @@ public enum OllamaModel {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
/**
 * {@link ModelDescription} accessor for this enum constant.
 * @return the {@code id} field of this constant
 */
@Override
|
||||
public String getModelName() {
|
||||
return this.id;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -714,6 +714,43 @@ public class OllamaOptions implements ChatOptions, EmbeddingOptions {
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
}
|
||||
|
||||
public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
|
||||
return new OllamaOptions()
|
||||
.withModel(fromOptions.getModel())
|
||||
.withFormat(fromOptions.getFormat())
|
||||
.withKeepAlive(fromOptions.getKeepAlive())
|
||||
.withUseNUMA(fromOptions.getUseNUMA())
|
||||
.withNumCtx(fromOptions.getNumCtx())
|
||||
.withNumBatch(fromOptions.getNumBatch())
|
||||
.withNumGQA(fromOptions.getNumGQA())
|
||||
.withNumGPU(fromOptions.getNumGPU())
|
||||
.withMainGPU(fromOptions.getMainGPU())
|
||||
.withLowVRAM(fromOptions.getLowVRAM())
|
||||
.withF16KV(fromOptions.getF16KV())
|
||||
.withLogitsAll(fromOptions.getLogitsAll())
|
||||
.withVocabOnly(fromOptions.getVocabOnly())
|
||||
.withUseMMap(fromOptions.getUseMMap())
|
||||
.withUseMLock(fromOptions.getUseMLock())
|
||||
.withNumThread(fromOptions.getNumThread())
|
||||
.withNumKeep(fromOptions.getNumKeep())
|
||||
.withSeed(fromOptions.getSeed())
|
||||
.withNumPredict(fromOptions.getNumPredict())
|
||||
.withTopK(fromOptions.getTopK())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTfsZ(fromOptions.getTfsZ())
|
||||
.withTypicalP(fromOptions.getTypicalP())
|
||||
.withRepeatLastN(fromOptions.getRepeatLastN())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withRepeatPenalty(fromOptions.getRepeatPenalty())
|
||||
.withPresencePenalty(fromOptions.getPresencePenalty())
|
||||
.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.withMirostat(fromOptions.getMirostat())
|
||||
.withMirostatTau(fromOptions.getMirostatTau())
|
||||
.withMirostatEta(fromOptions.getMirostatEta())
|
||||
.withPenalizeNewline(fromOptions.getPenalizeNewline())
|
||||
.withStop(fromOptions.getStop());
|
||||
}
|
||||
|
||||
|
||||
// @formatter:on
|
||||
|
||||
|
||||
@@ -56,11 +56,11 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@Testcontainers
|
||||
@Disabled("For manual smoke testing only.")
|
||||
class OllamaChatClientIT {
|
||||
class OllamaChatModelIT {
|
||||
|
||||
private static String MODEL = "mistral";
|
||||
|
||||
private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class);
|
||||
private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class);
|
||||
|
||||
@Container
|
||||
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32");
|
||||
@@ -77,7 +77,7 @@ class OllamaChatClientIT {
|
||||
}
|
||||
|
||||
@Autowired
|
||||
private OllamaChatClient client;
|
||||
private OllamaChatModel chatModel;
|
||||
|
||||
@Test
|
||||
void roleTest() {
|
||||
@@ -95,13 +95,13 @@ class OllamaChatClientIT {
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), portableOptions);
|
||||
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
|
||||
// ollama specific options
|
||||
var ollamaOptions = new OllamaOptions().withLowVRAM(true);
|
||||
|
||||
response = client.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
|
||||
response = chatModel.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
|
||||
|
||||
}
|
||||
@@ -109,7 +109,7 @@ class OllamaChatClientIT {
|
||||
@Test
|
||||
void usageTest() {
|
||||
Prompt prompt = new Prompt("Tell me a joke");
|
||||
ChatResponse response = client.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
Usage usage = response.getMetadata().getUsage();
|
||||
|
||||
assertThat(usage).isNotNull();
|
||||
@@ -131,7 +131,7 @@ class OllamaChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors.", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.client.call(prompt).getResult();
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@@ -151,7 +151,7 @@ class OllamaChatClientIT {
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
@@ -173,7 +173,7 @@ class OllamaChatClientIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = client.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
@@ -194,7 +194,7 @@ class OllamaChatClientIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = client.stream(prompt)
|
||||
String generationTextFromStream = chatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
@@ -219,8 +219,8 @@ class OllamaChatClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
|
||||
return new OllamaChatClient(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
|
||||
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
|
||||
return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -44,11 +44,11 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
@SpringBootTest
|
||||
@Testcontainers
|
||||
@Disabled("For manual smoke testing only.")
|
||||
class OllamaChatClientMultimodalIT {
|
||||
class OllamaChatModelMultimodalIT {
|
||||
|
||||
private static String MODEL = "llava";
|
||||
|
||||
private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class);
|
||||
// Logger is keyed to this test class so its output is attributed to the multimodal IT.
private static final Log logger = LogFactory.getLog(OllamaChatModelMultimodalIT.class);
|
||||
|
||||
@Container
|
||||
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32");
|
||||
@@ -65,7 +65,7 @@ class OllamaChatClientMultimodalIT {
|
||||
}
|
||||
|
||||
@Autowired
|
||||
private OllamaChatClient client;
|
||||
private OllamaChatModel chatModel;
|
||||
|
||||
@Test
|
||||
void multiModalityTest() throws IOException {
|
||||
@@ -75,7 +75,7 @@ class OllamaChatClientMultimodalIT {
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?",
|
||||
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
||||
|
||||
var response = client.call(new Prompt(List.of(userMessage)));
|
||||
var response = chatModel.call(new Prompt(List.of(userMessage)));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
|
||||
@@ -90,8 +90,8 @@ class OllamaChatClientMultimodalIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
|
||||
return new OllamaChatClient(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
|
||||
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
|
||||
return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -30,13 +30,13 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
*/
|
||||
public class OllamaChatRequestTests {
|
||||
|
||||
OllamaChatClient client = new OllamaChatClient(new OllamaApi(),
|
||||
OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
|
||||
new OllamaOptions().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6f).withNumGPU(1));
|
||||
|
||||
@Test
|
||||
public void createRequestWithDefaultOptions() {
|
||||
|
||||
var request = client.ollamaChatRequest(new Prompt("Test message content"), false);
|
||||
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), false);
|
||||
|
||||
assertThat(request.messages()).hasSize(1);
|
||||
assertThat(request.stream()).isFalse();
|
||||
@@ -54,7 +54,7 @@ public class OllamaChatRequestTests {
|
||||
// Runtime options should override the default options.
|
||||
OllamaOptions promptOptions = new OllamaOptions().withTemperature(0.8f).withTopP(0.5f).withNumGPU(2);
|
||||
|
||||
var request = client.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
|
||||
assertThat(request.messages()).hasSize(1);
|
||||
assertThat(request.stream()).isTrue();
|
||||
@@ -79,7 +79,7 @@ public class OllamaChatRequestTests {
|
||||
.withTopP(0.6f)
|
||||
.build();
|
||||
|
||||
var request = client.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
|
||||
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
|
||||
|
||||
assertThat(request.messages()).hasSize(1);
|
||||
assertThat(request.stream()).isTrue();
|
||||
@@ -97,7 +97,7 @@ public class OllamaChatRequestTests {
|
||||
// Ollama runtime options.
|
||||
OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL");
|
||||
|
||||
var request = client.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
|
||||
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
|
||||
}
|
||||
@@ -105,17 +105,17 @@ public class OllamaChatRequestTests {
|
||||
@Test
|
||||
public void createRequestWithDefaultOptionsModelOverride() {
|
||||
|
||||
OllamaChatClient client2 = new OllamaChatClient(new OllamaApi(),
|
||||
OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
|
||||
new OllamaOptions().withModel("DEFAULT_OPTIONS_MODEL"));
|
||||
|
||||
var request = client2.ollamaChatRequest(new Prompt("Test message content"), true);
|
||||
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
|
||||
|
||||
assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL");
|
||||
|
||||
// Prompt options should override the default options.
|
||||
OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL");
|
||||
|
||||
request = client2.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
|
||||
|
||||
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
|
||||
}
|
||||
|
||||
@@ -23,9 +23,9 @@ import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.testcontainers.containers.GenericContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
import org.testcontainers.ollama.OllamaContainer;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.ollama.api.OllamaApi;
|
||||
@@ -34,14 +34,13 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.testcontainers.ollama.OllamaContainer;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest
|
||||
@Disabled("For manual smoke testing only.")
|
||||
@Testcontainers
|
||||
class OllamaEmbeddingClientIT {
|
||||
class OllamaEmbeddingModelIT {
|
||||
|
||||
// Logger is keyed to this test class so its output is attributed to the embedding IT.
private static final Log logger = LogFactory.getLog(OllamaEmbeddingModelIT.class);
|
||||
|
||||
@@ -60,15 +59,15 @@ class OllamaEmbeddingClientIT {
|
||||
}
|
||||
|
||||
@Autowired
|
||||
private OllamaEmbeddingClient embeddingClient;
|
||||
private OllamaEmbeddingModel embeddingModel;
|
||||
|
||||
@Test
|
||||
void singleEmbedding() {
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
|
||||
assertThat(embeddingResponse.getResults()).hasSize(1);
|
||||
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
|
||||
assertThat(embeddingClient.dimensions()).isEqualTo(3200);
|
||||
assertThat(embeddingModel.dimensions()).isEqualTo(3200);
|
||||
}
|
||||
|
||||
@SpringBootConfiguration
|
||||
@@ -80,8 +79,8 @@ class OllamaEmbeddingClientIT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OllamaEmbeddingClient ollamaEmbedding(OllamaApi ollamaApi) {
|
||||
return new OllamaEmbeddingClient(ollamaApi).withModel("orca-mini");
|
||||
public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) {
|
||||
return new OllamaEmbeddingModel(ollamaApi).withModel("orca-mini");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -28,13 +28,13 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
*/
|
||||
public class OllamaEmbeddingRequestTests {
|
||||
|
||||
OllamaEmbeddingClient client = new OllamaEmbeddingClient(new OllamaApi()).withDefaultOptions(
|
||||
OllamaEmbeddingModel chatModel = new OllamaEmbeddingModel(new OllamaApi()).withDefaultOptions(
|
||||
new OllamaOptions().withModel("DEFAULT_MODEL").withMainGPU(11).withUseMMap(true).withNumGPU(1));
|
||||
|
||||
@Test
|
||||
public void ollamaEmbeddingRequestDefaultOptions() {
|
||||
|
||||
var request = client.ollamaEmbeddingRequest("Hello", null);
|
||||
var request = chatModel.ollamaEmbeddingRequest("Hello", null);
|
||||
|
||||
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
|
||||
assertThat(request.options().get("num_gpu")).isEqualTo(1);
|
||||
@@ -51,7 +51,7 @@ public class OllamaEmbeddingRequestTests {
|
||||
.withUseMMap(true)
|
||||
.withNumGPU(2);
|
||||
|
||||
var request = client.ollamaEmbeddingRequest("Hello", promptOptions);
|
||||
var request = chatModel.ollamaEmbeddingRequest("Hello", promptOptions);
|
||||
|
||||
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
|
||||
assertThat(request.options().get("num_gpu")).isEqualTo(2);
|
||||
|
||||
@@ -24,10 +24,10 @@ import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
|
||||
import org.springframework.ai.openai.api.common.OpenAiApiException;
|
||||
import org.springframework.ai.openai.audio.speech.Speech;
|
||||
import org.springframework.ai.openai.audio.speech.SpeechClient;
|
||||
import org.springframework.ai.openai.audio.speech.SpeechModel;
|
||||
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
|
||||
import org.springframework.ai.openai.audio.speech.SpeechResponse;
|
||||
import org.springframework.ai.openai.audio.speech.StreamingSpeechClient;
|
||||
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
|
||||
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
|
||||
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -44,7 +44,7 @@ import java.time.Duration;
|
||||
* @see OpenAiAudioApi
|
||||
* @since 1.0.0-M1
|
||||
*/
|
||||
public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechClient {
|
||||
public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -61,12 +61,12 @@ public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechCli
|
||||
private final OpenAiAudioApi audioApi;
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiAudioSpeechClient class with the provided
|
||||
* Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
|
||||
* OpenAiAudioApi. It uses the model tts-1, response format mp3, voice alloy, and the
|
||||
* default speed of 1.0.
|
||||
* @param audioApi The OpenAiAudioApi to use for speech synthesis.
|
||||
*/
|
||||
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi) {
|
||||
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi) {
|
||||
this(audioApi,
|
||||
OpenAiAudioSpeechOptions.builder()
|
||||
.withModel(OpenAiAudioApi.TtsModel.TTS_1.getValue())
|
||||
@@ -77,13 +77,13 @@ public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechCli
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiAudioSpeechClient class with the provided
|
||||
* Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
|
||||
* OpenAiAudioApi and options.
|
||||
* @param audioApi The OpenAiAudioApi to use for speech synthesis.
|
||||
* @param options The OpenAiAudioSpeechOptions containing the speech synthesis
|
||||
* options.
|
||||
*/
|
||||
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
|
||||
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
|
||||
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
|
||||
Assert.notNull(options, "OpenAiSpeechOptions must not be null");
|
||||
this.audioApi = audioApi;
|
||||
@@ -35,7 +35,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.model.ModelClient;
|
||||
import org.springframework.ai.model.Model;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
|
||||
import org.springframework.ai.openai.audio.transcription.AudioTranscription;
|
||||
@@ -59,8 +59,7 @@ import org.springframework.util.Assert;
|
||||
* @see OpenAiAudioApi
|
||||
* @since 0.8.1
|
||||
*/
|
||||
public class OpenAiAudioTranscriptionClient
|
||||
implements ModelClient<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
|
||||
public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@@ -71,11 +70,11 @@ public class OpenAiAudioTranscriptionClient
|
||||
private final OpenAiAudioApi audioApi;
|
||||
|
||||
/**
|
||||
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
|
||||
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
|
||||
* Audio Transcription API.
|
||||
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
|
||||
*/
|
||||
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi) {
|
||||
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
|
||||
this(audioApi,
|
||||
OpenAiAudioTranscriptionOptions.builder()
|
||||
.withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
|
||||
@@ -86,25 +85,25 @@ public class OpenAiAudioTranscriptionClient
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
|
||||
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
|
||||
* Audio Transcription API.
|
||||
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
|
||||
* @param options The OpenAiAudioTranscriptionOptions instance for configuring the
|
||||
* audio transcription.
|
||||
*/
|
||||
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) {
|
||||
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) {
|
||||
this(audioApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
|
||||
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
|
||||
* Audio Transcription API.
|
||||
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
|
||||
* @param options The OpenAiAudioTranscriptionOptions instance for configuring the
|
||||
* audio transcription.
|
||||
* @param retryTemplate The RetryTemplate instance for retrying failed API calls.
|
||||
*/
|
||||
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options,
|
||||
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options,
|
||||
RetryTemplate retryTemplate) {
|
||||
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
|
||||
Assert.notNull(options, "OpenAiTranscriptionOptions must not be null");
|
||||
@@ -17,10 +17,10 @@ package org.springframework.ai.openai;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.StreamingChatClient;
|
||||
import org.springframework.ai.chat.StreamingChatModel;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
@@ -58,7 +58,7 @@ import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal OpenAI}
|
||||
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
|
||||
* backed by {@link OpenAiApi}.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
@@ -68,15 +68,15 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
* @author Josh Long
|
||||
* @author Jemin Huh
|
||||
* @author Grogdunn
|
||||
* @see ChatClient
|
||||
* @see StreamingChatClient
|
||||
* @see ChatModel
|
||||
* @see StreamingChatModel
|
||||
* @see OpenAiApi
|
||||
*/
|
||||
public class OpenAiChatClient extends
|
||||
public class OpenAiChatModel extends
|
||||
AbstractFunctionCallSupport<ChatCompletionMessage, OpenAiApi.ChatCompletionRequest, ResponseEntity<ChatCompletion>>
|
||||
implements ChatClient, StreamingChatClient {
|
||||
implements ChatModel, StreamingChatModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClient.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
|
||||
|
||||
/**
|
||||
* The default options used for the chat completion requests.
|
||||
@@ -94,35 +94,35 @@ public class OpenAiChatClient extends
|
||||
private final OpenAiApi openAiApi;
|
||||
|
||||
/**
|
||||
* Creates an instance of the OpenAiChatClient.
|
||||
* Creates an instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @throws IllegalArgumentException if openAiApi is null
|
||||
*/
|
||||
public OpenAiChatClient(OpenAiApi openAiApi) {
|
||||
public OpenAiChatModel(OpenAiApi openAiApi) {
|
||||
this(openAiApi,
|
||||
OpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes an instance of the OpenAiChatClient.
|
||||
* Initializes an instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat client.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
*/
|
||||
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options) {
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
|
||||
this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiChatClient.
|
||||
* Initializes a new instance of the OpenAiChatModel.
|
||||
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
|
||||
* Chat API.
|
||||
* @param options The OpenAiChatOptions to configure the chat client.
|
||||
* @param options The OpenAiChatOptions to configure the chat model.
|
||||
* @param functionCallbackContext The function callback context.
|
||||
* @param retryTemplate The retry template.
|
||||
*/
|
||||
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options,
|
||||
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
|
||||
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
|
||||
super(functionCallbackContext);
|
||||
Assert.notNull(openAiApi, "OpenAiApi must not be null");
|
||||
@@ -394,4 +394,9 @@ public class OpenAiChatClient extends
|
||||
&& choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return OpenAiChatOptions.fromOptions(this.defaultOptions);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -134,10 +134,10 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
|
||||
private @JsonProperty("user") String user;
|
||||
|
||||
/**
|
||||
* OpenAI Tool Function Callbacks to register with the ChatClient.
|
||||
* OpenAI Tool Function Callbacks to register with the ChatModel.
|
||||
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
|
||||
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
|
||||
* from the registry to be used by the ChatClient chat completion requests.
|
||||
* from the registry to be used by the ChatModel chat completion requests.
|
||||
*/
|
||||
@NestedConfigurationProperty
|
||||
@JsonIgnore
|
||||
@@ -567,4 +567,27 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
|
||||
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
|
||||
}
|
||||
|
||||
public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
|
||||
return OpenAiChatOptions.builder()
|
||||
.withModel(fromOptions.getModel())
|
||||
.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
|
||||
.withLogitBias(fromOptions.getLogitBias())
|
||||
.withLogprobs(fromOptions.getLogprobs())
|
||||
.withTopLogprobs(fromOptions.getTopLogprobs())
|
||||
.withMaxTokens(fromOptions.getMaxTokens())
|
||||
.withN(fromOptions.getN())
|
||||
.withPresencePenalty(fromOptions.getPresencePenalty())
|
||||
.withResponseFormat(fromOptions.getResponseFormat())
|
||||
.withSeed(fromOptions.getSeed())
|
||||
.withStop(fromOptions.getStop())
|
||||
.withTemperature(fromOptions.getTemperature())
|
||||
.withTopP(fromOptions.getTopP())
|
||||
.withTools(fromOptions.getTools())
|
||||
.withToolChoice(fromOptions.getToolChoice())
|
||||
.withUser(fromOptions.getUser())
|
||||
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
|
||||
.withFunctions(fromOptions.getFunctions())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.document.MetadataMode;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingClient;
|
||||
import org.springframework.ai.embedding.AbstractEmbeddingModel;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingOptions;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
@@ -41,9 +41,9 @@ import org.springframework.util.Assert;
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class OpenAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingClient.class);
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingModel.class);
|
||||
|
||||
private final OpenAiEmbeddingOptions defaultOptions;
|
||||
|
||||
@@ -54,43 +54,43 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient {
|
||||
private final MetadataMode metadataMode;
|
||||
|
||||
/**
|
||||
* Constructor for the OpenAiEmbeddingClient class.
|
||||
* Constructor for the OpenAiEmbeddingModel class.
|
||||
* @param openAiApi The OpenAiApi instance to use for making API requests.
|
||||
*/
|
||||
public OpenAiEmbeddingClient(OpenAiApi openAiApi) {
|
||||
public OpenAiEmbeddingModel(OpenAiApi openAiApi) {
|
||||
this(openAiApi, MetadataMode.EMBED);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiEmbeddingClient class.
|
||||
* Initializes a new instance of the OpenAiEmbeddingModel class.
|
||||
* @param openAiApi The OpenAiApi instance to use for making API requests.
|
||||
* @param metadataMode The mode for generating metadata.
|
||||
*/
|
||||
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) {
|
||||
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode) {
|
||||
this(openAiApi, metadataMode,
|
||||
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
|
||||
RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiEmbeddingClient class.
|
||||
* Initializes a new instance of the OpenAiEmbeddingModel class.
|
||||
* @param openAiApi The OpenAiApi instance to use for making API requests.
|
||||
* @param metadataMode The mode for generating metadata.
|
||||
* @param openAiEmbeddingOptions The options for OpenAi embedding.
|
||||
*/
|
||||
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode,
|
||||
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode,
|
||||
OpenAiEmbeddingOptions openAiEmbeddingOptions) {
|
||||
this(openAiApi, metadataMode, openAiEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a new instance of the OpenAiEmbeddingClient class.
|
||||
* Initializes a new instance of the OpenAiEmbeddingModel class.
|
||||
* @param openAiApi - The OpenAiApi instance to use for making API requests.
|
||||
* @param metadataMode - The mode for generating metadata.
|
||||
* @param options - The options for OpenAI embedding.
|
||||
* @param retryTemplate - The RetryTemplate for retrying failed API requests.
|
||||
*/
|
||||
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
|
||||
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
|
||||
RetryTemplate retryTemplate) {
|
||||
Assert.notNull(openAiApi, "OpenAiService must not be null");
|
||||
Assert.notNull(metadataMode, "metadataMode must not be null");
|
||||
@@ -21,7 +21,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.image.Image;
|
||||
import org.springframework.ai.image.ImageClient;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.image.ImageGeneration;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
@@ -37,16 +37,16 @@ import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* OpenAiImageClient is a class that implements the ImageClient interface. It provides a
|
||||
* OpenAiImageModel is a class that implements the ImageModel interface. It provides a
|
||||
* client for calling the OpenAI image generation API.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
* @author Christian Tzolov
|
||||
* @since 0.8.0
|
||||
*/
|
||||
public class OpenAiImageClient implements ImageClient {
|
||||
public class OpenAiImageModel implements ImageModel {
|
||||
|
||||
private final static Logger logger = LoggerFactory.getLogger(OpenAiImageClient.class);
|
||||
private final static Logger logger = LoggerFactory.getLogger(OpenAiImageModel.class);
|
||||
|
||||
private OpenAiImageOptions defaultOptions;
|
||||
|
||||
@@ -54,11 +54,11 @@ public class OpenAiImageClient implements ImageClient {
|
||||
|
||||
public final RetryTemplate retryTemplate;
|
||||
|
||||
public OpenAiImageClient(OpenAiImageApi openAiImageApi) {
|
||||
public OpenAiImageModel(OpenAiImageApi openAiImageApi) {
|
||||
this(openAiImageApi, OpenAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public OpenAiImageClient(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions,
|
||||
public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions,
|
||||
RetryTemplate retryTemplate) {
|
||||
Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null");
|
||||
Assert.notNull(defaultOptions, "defaultOptions must not be null");
|
||||
@@ -26,6 +26,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import org.springframework.ai.model.ModelDescription;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.boot.context.properties.bind.ConstructorBinding;
|
||||
@@ -113,7 +114,7 @@ public class OpenAiApi {
|
||||
* - <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a>
|
||||
* - <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
|
||||
*/
|
||||
public enum ChatModel {
|
||||
public enum ChatModel implements ModelDescription {
|
||||
/**
|
||||
* Multimodal flagship model that’s cheaper and faster than GPT-4 Turbo.
|
||||
* Currently points to gpt-4o-2024-05-13.
|
||||
@@ -199,6 +200,11 @@ public class OpenAiApi {
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModelName() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
|
||||
package org.springframework.ai.openai.audio.speech;
|
||||
|
||||
import org.springframework.ai.model.ModelClient;
|
||||
import org.springframework.ai.model.Model;
|
||||
|
||||
/**
|
||||
* The {@link SpeechClient} interface provides a way to interact with the OpenAI
|
||||
* The {@link SpeechModel} interface provides a way to interact with the OpenAI
|
||||
* Text-to-Speech (TTS) API. It allows you to convert text input into lifelike spoken
|
||||
* audio.
|
||||
*
|
||||
@@ -27,7 +27,7 @@ import org.springframework.ai.model.ModelClient;
|
||||
* @since 1.0.0-M1
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface SpeechClient extends ModelClient<SpeechPrompt, SpeechResponse> {
|
||||
public interface SpeechModel extends Model<SpeechPrompt, SpeechResponse> {
|
||||
|
||||
/**
|
||||
* Generates spoken audio from the provided text message.
|
||||
@@ -16,11 +16,11 @@
|
||||
|
||||
package org.springframework.ai.openai.audio.speech;
|
||||
|
||||
import org.springframework.ai.model.StreamingModelClient;
|
||||
import org.springframework.ai.model.StreamingModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* The {@link StreamingSpeechClient} interface provides a way to interact with the OpenAI
|
||||
* The {@link StreamingSpeechModel} interface provides a way to interact with the OpenAI
|
||||
* Text-to-Speech (TTS) API using a streaming approach, allowing you to receive the
|
||||
* generated audio in a real-time fashion.
|
||||
*
|
||||
@@ -28,7 +28,7 @@ import reactor.core.publisher.Flux;
|
||||
* @since 1.0.0-M1
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface StreamingSpeechClient extends StreamingModelClient<SpeechPrompt, SpeechResponse> {
|
||||
public interface StreamingSpeechModel extends StreamingModel<SpeechPrompt, SpeechResponse> {
|
||||
|
||||
/**
|
||||
* Generates a stream of audio bytes from the provided text message.
|
||||
@@ -34,7 +34,7 @@ public class ChatCompletionRequestTests {
|
||||
@Test
|
||||
public void createRequestWithChatOptions() {
|
||||
|
||||
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content"), false);
|
||||
@@ -60,7 +60,7 @@ public class ChatCompletionRequestTests {
|
||||
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build());
|
||||
|
||||
var request = client.createRequest(new Prompt("Test message content",
|
||||
@@ -90,7 +90,7 @@ public class ChatCompletionRequestTests {
|
||||
|
||||
final String TOOL_FUNCTION_NAME = "CurrentWeather";
|
||||
|
||||
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
|
||||
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
|
||||
OpenAiChatOptions.builder()
|
||||
.withModel("DEFAULT_MODEL")
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
package org.springframework.ai.openai;
|
||||
|
||||
import org.springframework.ai.embedding.EmbeddingClient;
|
||||
import org.springframework.ai.embedding.EmbeddingModel;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
@@ -51,33 +51,33 @@ public class OpenAiTestConfiguration {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatClient openAiChatClient(OpenAiApi api) {
|
||||
OpenAiChatClient openAiChatClient = new OpenAiChatClient(api);
|
||||
return openAiChatClient;
|
||||
public OpenAiChatModel openAiChatModel(OpenAiApi api) {
|
||||
OpenAiChatModel openAiChatModel = new OpenAiChatModel(api);
|
||||
return openAiChatModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioTranscriptionClient openAiTranscriptionClient(OpenAiAudioApi api) {
|
||||
OpenAiAudioTranscriptionClient openAiTranscriptionClient = new OpenAiAudioTranscriptionClient(api);
|
||||
return openAiTranscriptionClient;
|
||||
public OpenAiAudioTranscriptionModel openAiTranscriptionModel(OpenAiAudioApi api) {
|
||||
OpenAiAudioTranscriptionModel openAiTranscriptionModel = new OpenAiAudioTranscriptionModel(api);
|
||||
return openAiTranscriptionModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi api) {
|
||||
OpenAiAudioSpeechClient openAiAudioSpeechClient = new OpenAiAudioSpeechClient(api);
|
||||
return openAiAudioSpeechClient;
|
||||
public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
|
||||
OpenAiAudioSpeechModel openAiAudioSpeechModel = new OpenAiAudioSpeechModel(api);
|
||||
return openAiAudioSpeechModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) {
|
||||
OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi);
|
||||
// openAiImageClient.setModel("foobar");
|
||||
return openAiImageClient;
|
||||
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
|
||||
OpenAiImageModel openAiImageModel = new OpenAiImageModel(imageApi);
|
||||
// openAiImageModel.setModel("foobar");
|
||||
return openAiImageModel;
|
||||
}
|
||||
|
||||
@Bean
|
||||
public EmbeddingClient openAiEmbeddingClient(OpenAiApi api) {
|
||||
return new OpenAiEmbeddingClient(api);
|
||||
public EmbeddingModel openAiEmbeddingModel(OpenAiApi api) {
|
||||
return new OpenAiEmbeddingModel(api);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ public class TranscriptionRequestTests {
|
||||
@Test
|
||||
public void defaultOptions() {
|
||||
|
||||
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
|
||||
var client = new OpenAiAudioTranscriptionModel(new OpenAiAudioApi("TEST"),
|
||||
OpenAiAudioTranscriptionOptions.builder()
|
||||
.withModel("DEFAULT_MODEL")
|
||||
.withResponseFormat(TranscriptResponseFormat.TEXT)
|
||||
@@ -58,7 +58,7 @@ public class TranscriptionRequestTests {
|
||||
@Test
|
||||
public void runtimeOptions() {
|
||||
|
||||
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
|
||||
var client = new OpenAiAudioTranscriptionModel(new OpenAiAudioApi("TEST"),
|
||||
OpenAiAudioTranscriptionOptions.builder()
|
||||
.withModel("DEFAULT_MODEL")
|
||||
.withResponseFormat(TranscriptResponseFormat.TEXT)
|
||||
|
||||
@@ -26,9 +26,9 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.openai.OpenAiChatClient;
|
||||
import org.springframework.ai.openai.OpenAiEmbeddingClient;
|
||||
import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
@@ -58,16 +58,16 @@ public class AcmeIT extends AbstractIT {
|
||||
private Resource systemBikePrompt;
|
||||
|
||||
@Autowired
|
||||
private OpenAiEmbeddingClient embeddingClient;
|
||||
private OpenAiEmbeddingModel embeddingModel;
|
||||
|
||||
@Autowired
|
||||
private OpenAiChatClient chatClient;
|
||||
private OpenAiChatModel chatModel;
|
||||
|
||||
@Test
|
||||
void beanTest() {
|
||||
assertThat(bikesResource).isNotNull();
|
||||
assertThat(embeddingClient).isNotNull();
|
||||
assertThat(chatClient).isNotNull();
|
||||
assertThat(embeddingModel).isNotNull();
|
||||
assertThat(chatModel).isNotNull();
|
||||
}
|
||||
|
||||
// @Test
|
||||
@@ -81,7 +81,7 @@ public class AcmeIT extends AbstractIT {
|
||||
// Step 2 - Create embeddings and save to vector store
|
||||
|
||||
logger.info("Creating Embeddings...");
|
||||
VectorStore vectorStore = new SimpleVectorStore(embeddingClient);
|
||||
VectorStore vectorStore = new SimpleVectorStore(embeddingModel);
|
||||
|
||||
vectorStore.accept(textSplitter.apply(jsonReader.get()));
|
||||
|
||||
@@ -108,7 +108,7 @@ public class AcmeIT extends AbstractIT {
|
||||
logger.info("Asking AI generative to reply to question.");
|
||||
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
|
||||
logger.info("AI responded.");
|
||||
ChatResponse response = chatClient.call(prompt);
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
|
||||
evaluateQuestionAndAnswer(userQuery, response, true);
|
||||
}
|
||||
|
||||
@@ -32,13 +32,13 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = OpenAiTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
class OpenAiSpeechClientIT extends AbstractIT {
|
||||
class OpenAiSpeechModelIT extends AbstractIT {
|
||||
|
||||
private static final Float SPEED = 1.0f;
|
||||
|
||||
@Test
|
||||
void shouldSuccessfullyStreamAudioBytesForEmptyMessage() {
|
||||
Flux<byte[]> response = speechClient.stream("Today is a wonderful day to build something people love!");
|
||||
Flux<byte[]> response = speechModel.stream("Today is a wonderful day to build something people love!");
|
||||
assertThat(response).isNotNull();
|
||||
assertThat(response.collectList().block()).isNotNull();
|
||||
System.out.println(response.collectList().block());
|
||||
@@ -46,7 +46,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
|
||||
|
||||
@Test
|
||||
void shouldProduceAudioBytesDirectlyFromMessage() {
|
||||
byte[] audioBytes = speechClient.call("Today is a wonderful day to build something people love!");
|
||||
byte[] audioBytes = speechModel.call("Today is a wonderful day to build something people love!");
|
||||
assertThat(audioBytes).hasSizeGreaterThan(0);
|
||||
|
||||
}
|
||||
@@ -61,7 +61,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
|
||||
.build();
|
||||
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
|
||||
speechOptions);
|
||||
SpeechResponse response = speechClient.call(speechPrompt);
|
||||
SpeechResponse response = speechModel.call(speechPrompt);
|
||||
byte[] audioBytes = response.getResult().getOutput();
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput()).isNotEmpty();
|
||||
@@ -79,7 +79,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
|
||||
.build();
|
||||
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
|
||||
speechOptions);
|
||||
SpeechResponse response = speechClient.call(speechPrompt);
|
||||
SpeechResponse response = speechModel.call(speechPrompt);
|
||||
OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata();
|
||||
assertThat(metadata).isNotNull();
|
||||
assertThat(metadata.getRateLimit()).isNotNull();
|
||||
@@ -100,7 +100,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
|
||||
|
||||
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
|
||||
speechOptions);
|
||||
Flux<SpeechResponse> responseFlux = speechClient.stream(speechPrompt);
|
||||
Flux<SpeechResponse> responseFlux = speechModel.stream(speechPrompt);
|
||||
assertThat(responseFlux).isNotNull();
|
||||
List<SpeechResponse> responses = responseFlux.collectList().block();
|
||||
assertThat(responses).isNotNull();
|
||||
@@ -18,7 +18,7 @@ package org.springframework.ai.openai.audio.speech;
|
||||
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.openai.OpenAiAudioSpeechClient;
|
||||
import org.springframework.ai.openai.OpenAiAudioSpeechModel;
|
||||
import org.springframework.ai.openai.OpenAiAudioSpeechOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
|
||||
@@ -45,15 +45,15 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
|
||||
/**
|
||||
* @author Ahmed Yousri
|
||||
*/
|
||||
@RestClientTest(OpenAiSpeechClientWithSpeechResponseMetadataTests.Config.class)
|
||||
public class OpenAiSpeechClientWithSpeechResponseMetadataTests {
|
||||
@RestClientTest(OpenAiSpeechModelWithSpeechResponseMetadataTests.Config.class)
|
||||
public class OpenAiSpeechModelWithSpeechResponseMetadataTests {
|
||||
|
||||
private static String TEST_API_KEY = "sk-1234567890";
|
||||
|
||||
private static final Float SPEED = 1.0f;
|
||||
|
||||
@Autowired
|
||||
private OpenAiAudioSpeechClient openAiSpeechClient;
|
||||
private OpenAiAudioSpeechModel openAiSpeechClient;
|
||||
|
||||
@Autowired
|
||||
private MockRestServiceServer server;
|
||||
@@ -121,8 +121,8 @@ public class OpenAiSpeechClientWithSpeechResponseMetadataTests {
|
||||
static class Config {
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi openAiAudioApi) {
|
||||
return new OpenAiAudioSpeechClient(openAiAudioApi);
|
||||
public OpenAiAudioSpeechModel openAiAudioSpeechClient(OpenAiAudioApi openAiAudioApi) {
|
||||
return new OpenAiAudioSpeechModel(openAiAudioApi);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@@ -31,7 +31,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = OpenAiTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
class OpenAiTranscriptionClientIT extends AbstractIT {
|
||||
class OpenAiTranscriptionModelIT extends AbstractIT {
|
||||
|
||||
@Value("classpath:/speech/jfk.flac")
|
||||
private Resource audioFile;
|
||||
@@ -43,7 +43,7 @@ class OpenAiTranscriptionClientIT extends AbstractIT {
|
||||
.withTemperature(0f)
|
||||
.build();
|
||||
AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
|
||||
AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
|
||||
AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
|
||||
}
|
||||
@@ -59,7 +59,7 @@ class OpenAiTranscriptionClientIT extends AbstractIT {
|
||||
.withResponseFormat(responseFormat)
|
||||
.build();
|
||||
AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
|
||||
AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
|
||||
AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
|
||||
}
|
||||
@@ -21,7 +21,7 @@ import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.metadata.RateLimit;
|
||||
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
|
||||
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
|
||||
import org.springframework.ai.openai.api.OpenAiAudioApi;
|
||||
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionMetadata;
|
||||
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
|
||||
@@ -48,13 +48,13 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
|
||||
/**
|
||||
* @author Michael Lavelle
|
||||
*/
|
||||
@RestClientTest(OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.Config.class)
|
||||
public class OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests {
|
||||
@RestClientTest(OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.Config.class)
|
||||
public class OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests {
|
||||
|
||||
private static String TEST_API_KEY = "sk-1234567890";
|
||||
|
||||
@Autowired
|
||||
private OpenAiAudioTranscriptionClient openAiTranscriptionClient;
|
||||
private OpenAiAudioTranscriptionModel openAiTranscriptionClient;
|
||||
|
||||
@Autowired
|
||||
private MockRestServiceServer server;
|
||||
@@ -156,8 +156,8 @@ public class OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiAudioTranscriptionClient openAiClient(OpenAiAudioApi openAiAudioApi) {
|
||||
return new OpenAiAudioTranscriptionClient(openAiAudioApi);
|
||||
public OpenAiAudioTranscriptionModel openAiClient(OpenAiAudioApi openAiAudioApi) {
|
||||
return new OpenAiAudioTranscriptionModel(openAiAudioApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,7 @@ package org.springframework.ai.openai.audio.transcription;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
|
||||
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
@@ -33,18 +33,18 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* Unit Tests for {@link TranscriptionClient}.
|
||||
* Unit Tests for {@link TranscriptionModel}.
|
||||
*
|
||||
* @author Michael Lavelle
|
||||
*/
|
||||
class TranscriptionClientTests {
|
||||
class TranscriptionModelTests {
|
||||
|
||||
@Test
|
||||
void transcrbeRequestReturnsResponseCorrectly() {
|
||||
|
||||
Resource mockAudioFile = Mockito.mock(Resource.class);
|
||||
|
||||
OpenAiAudioTranscriptionClient mockClient = Mockito.mock(OpenAiAudioTranscriptionClient.class);
|
||||
OpenAiAudioTranscriptionModel mockClient = Mockito.mock(OpenAiAudioTranscriptionModel.class);
|
||||
|
||||
String mockTranscription = "All your bases are belong to us";
|
||||
|
||||
@@ -17,8 +17,8 @@ package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -31,19 +31,10 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.ChatClient;
|
||||
import org.springframework.ai.chat.ChatModel;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallbackWrapper;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
@@ -51,7 +42,7 @@ import org.springframework.ai.openai.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.convert.support.DefaultConversionService;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
@@ -65,75 +56,96 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientIT.class);
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
private Resource systemTextResource;
|
||||
|
||||
@Test
|
||||
void roleTest() {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = chatClient.call(prompt);
|
||||
|
||||
// @formatter:off
|
||||
ChatResponse response = ChatClient.builder(chatModel).build().prompt()
|
||||
.system(s -> s.text(systemTextResource)
|
||||
.param("name", "Bob")
|
||||
.param("voice", "pirate"))
|
||||
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
|
||||
.call()
|
||||
.chatResponse();
|
||||
// @formatter:on
|
||||
|
||||
logger.info("" + response);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
|
||||
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
void listOutputConverter() {
|
||||
DefaultConversionService conversionService = new DefaultConversionService();
|
||||
ListOutputConverter outputConverter = new ListOutputConverter(conversionService);
|
||||
// @formatter:off
|
||||
Collection<String> collection = ChatClient.builder(chatModel).build().prompt()
|
||||
.user(u -> u.text("List five {subject}")
|
||||
.param("subject", "ice cream flavors"))
|
||||
.call()
|
||||
.list(String.class);
|
||||
// @formatter:on
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
List five {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatClient.call(prompt).getResult();
|
||||
assertThat(collection).hasSize(5);
|
||||
}
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
@Test
|
||||
void listOutputConverter2() {
|
||||
|
||||
// @formatter:off
|
||||
List<ActorsFilmsRecord> actorsFilms = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
|
||||
.call()
|
||||
.single(new ParameterizedTypeReference<List<ActorsFilmsRecord>>() {
|
||||
});
|
||||
// @formatter:on
|
||||
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms).hasSize(2);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void listOutputConverter3() {
|
||||
|
||||
// @formatter:off
|
||||
Collection<ActorsFilmsRecord> actorsFilms = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
|
||||
.call()
|
||||
.list(ActorsFilmsRecord.class);
|
||||
// @formatter:on
|
||||
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms).hasSize(2);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void mapOutputConverter() {
|
||||
MapOutputConverter outputConverter = new MapOutputConverter();
|
||||
// @formatter:off
|
||||
Map<String, Object> result = ChatClient.builder(chatModel).build().prompt()
|
||||
.user(u -> u.text("Provide me a List of {subject}")
|
||||
.param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'"))
|
||||
.call()
|
||||
.single(new ParameterizedTypeReference<Map<String, Object>>() {
|
||||
});
|
||||
// @formatter:on
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Provide me a List of {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverter() {
|
||||
|
||||
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
|
||||
// @formatter:off
|
||||
ActorsFilms actorsFilms = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("Generate the filmography for a random actor.")
|
||||
.call()
|
||||
.single(ActorsFilms.class);
|
||||
// @formatter:on
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography for a random actor.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
|
||||
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.getActor()).isNotBlank();
|
||||
}
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
@@ -142,18 +154,13 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
@Test
|
||||
void beanOutputConverterRecords() {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
// @formatter:off
|
||||
ActorsFilmsRecord actorsFilms = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("Generate the filmography of 5 movies for Tom Hanks.")
|
||||
.call()
|
||||
.single(ActorsFilmsRecord.class);
|
||||
// @formatter:on
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
@@ -164,25 +171,25 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
// @formatter:off
|
||||
Flux<String> chatResponse = ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
String generationTextFromStream = streamingChatClient.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
String generationTextFromStream = chatResponse.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.collect(Collectors.joining());
|
||||
// @formatter:on
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
|
||||
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
@@ -191,54 +198,33 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
@Test
|
||||
void functionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
.withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
.withName("getCurrentWeather")
|
||||
.withDescription("Get the weather in location")
|
||||
.withResponseConverter((response) -> "" + response.temp() + response.unit())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(chatModel).build().prompt()
|
||||
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
|
||||
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
|
||||
assertThat(response).containsAnyOf("30.0", "30");
|
||||
assertThat(response).containsAnyOf("10.0", "10");
|
||||
assertThat(response).containsAnyOf("15.0", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void streamFunctionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.builder(chatModel).build().prompt()
|
||||
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
|
||||
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
|
||||
.stream()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
.withName("getCurrentWeather")
|
||||
.withDescription("Get the weather in location")
|
||||
.withResponseConverter((response) -> "" + response.temp() + response.unit())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(messages, promptOptions));
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
String content = response.collectList().block().stream().collect(Collectors.joining());
|
||||
logger.info("Response: {}", content);
|
||||
|
||||
assertThat(content).containsAnyOf("30.0", "30");
|
||||
@@ -250,53 +236,62 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
|
||||
void multiModalityEmbeddedImage(String modelName) throws IOException {
|
||||
|
||||
var imageData = new ClassPathResource("/test.png");
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(chatModel).build().prompt()
|
||||
// TODO consider adding model(...) method to ChatClient as a shortcut to
|
||||
// OpenAiChatOptions.builder().withModel(modelName).build()
|
||||
.options(OpenAiChatOptions.builder().withModel(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?")
|
||||
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?",
|
||||
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
||||
|
||||
var response = chatClient
|
||||
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
|
||||
logger.info(response);
|
||||
assertThat(response).contains("bananas", "apple");
|
||||
assertThat(response).containsAnyOf("bowl", "basket");
|
||||
}
|
||||
|
||||
@ParameterizedTest(name = "{0} : {displayName} ")
|
||||
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
|
||||
void multiModalityImageUrl(String modelName) throws IOException {
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?", List
|
||||
.of(new Media(MimeTypeUtils.IMAGE_PNG,
|
||||
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
|
||||
// TODO: add url method that wraps the checked exception.
|
||||
URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
|
||||
|
||||
ChatResponse response = chatClient
|
||||
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
|
||||
// @formatter:off
|
||||
String response = ChatClient.builder(chatModel).build().prompt()
|
||||
// TODO consider adding model(...) method to ChatClient as a shortcut to
|
||||
// OpenAiChatOptions.builder().withModel(modelName).build()
|
||||
.options(OpenAiChatOptions.builder().withModel(modelName).build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
|
||||
.call()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
|
||||
logger.info(response);
|
||||
assertThat(response).contains("bananas", "apple");
|
||||
assertThat(response).containsAnyOf("bowl", "basket");
|
||||
}
|
||||
|
||||
@Test
|
||||
void streamingMultiModalityImageUrl() throws IOException {
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?", List
|
||||
.of(new Media(MimeTypeUtils.IMAGE_PNG,
|
||||
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
|
||||
// TODO: add url method that wraps the checked exception.
|
||||
URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
|
||||
|
||||
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(List.of(userMessage),
|
||||
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
|
||||
// @formatter:off
|
||||
Flux<String> response = ChatClient.builder(chatModel).build().prompt()
|
||||
.options(OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue())
|
||||
.build())
|
||||
.user(u -> u.text("Explain what do you see on this picture?")
|
||||
.media(MimeTypeUtils.IMAGE_PNG, url))
|
||||
.stream()
|
||||
.content();
|
||||
// @formatter:on
|
||||
|
||||
String content = response.collectList().block().stream().collect(Collectors.joining());
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
logger.info("Response: {}", content);
|
||||
assertThat(content).contains("bananas", "apple");
|
||||
assertThat(content).containsAnyOf("bowl", "basket");
|
||||
|
||||
@@ -27,7 +27,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatClient;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
|
||||
@@ -41,14 +41,14 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@SpringBootTest(classes = OpenAiChatClient2IT.Config.class)
|
||||
@SpringBootTest(classes = OpenAiChatModel2IT.Config.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
public class OpenAiChatClient2IT {
|
||||
public class OpenAiChatModel2IT {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(getClass());
|
||||
|
||||
@Autowired
|
||||
private OpenAiChatClient openAiChatClient;
|
||||
private OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Test
|
||||
void responseFormatTest() throws JsonMappingException, JsonProcessingException {
|
||||
@@ -67,7 +67,7 @@ public class OpenAiChatClient2IT {
|
||||
.withResponseFormat(new ChatCompletionRequest.ResponseFormat("json_object"))
|
||||
.build());
|
||||
|
||||
ChatResponse response = this.openAiChatClient.call(prompt);
|
||||
ChatResponse response = this.openAiChatModel.call(prompt);
|
||||
|
||||
assertThat(response).isNotNull();
|
||||
|
||||
@@ -99,8 +99,8 @@ public class OpenAiChatClient2IT {
|
||||
}
|
||||
|
||||
@Bean
|
||||
public OpenAiChatClient openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatClient(openAiApi);
|
||||
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
|
||||
return new OpenAiChatModel(openAiApi);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,305 @@
|
||||
/*
|
||||
* Copyright 2023 - 2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.openai.chat;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.Generation;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallbackWrapper;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.OpenAiTestConfiguration;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.convert.support.DefaultConversionService;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = OpenAiTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
class OpenAiChatModelIT extends AbstractIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class);
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
|
||||
@Test
|
||||
void roleTest() {
|
||||
UserMessage userMessage = new UserMessage(
|
||||
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
ChatResponse response = chatModel.call(prompt);
|
||||
assertThat(response.getResults()).hasSize(1);
|
||||
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
|
||||
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
void listOutputConverter() {
|
||||
DefaultConversionService conversionService = new DefaultConversionService();
|
||||
ListOutputConverter outputConverter = new ListOutputConverter(conversionService);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
List five {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "ice cream flavors", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = this.chatModel.call(prompt).getResult();
|
||||
|
||||
List<String> list = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(list).hasSize(5);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void mapOutputConverter() {
|
||||
MapOutputConverter outputConverter = new MapOutputConverter();
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Provide me a List of {subject}
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template,
|
||||
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
|
||||
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverter() {
|
||||
|
||||
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography for a random actor.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
}
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanOutputConverterRecords() {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
void beanStreamOutputConverterRecords() {
|
||||
|
||||
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
|
||||
|
||||
String format = outputConverter.getFormat();
|
||||
String template = """
|
||||
Generate the filmography of 5 movies for Tom Hanks.
|
||||
{format}
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
|
||||
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
|
||||
logger.info("" + actorsFilms);
|
||||
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
|
||||
assertThat(actorsFilms.movies()).hasSize(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
void functionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
.withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
.withName("getCurrentWeather")
|
||||
.withDescription("Get the weather in location")
|
||||
.withResponseConverter((response) -> "" + response.temp() + response.unit())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
|
||||
|
||||
logger.info("Response: {}", response);
|
||||
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void streamFunctionCallTest() {
|
||||
|
||||
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
|
||||
|
||||
List<Message> messages = new ArrayList<>(List.of(userMessage));
|
||||
|
||||
var promptOptions = OpenAiChatOptions.builder()
|
||||
// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
|
||||
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
|
||||
.withName("getCurrentWeather")
|
||||
.withDescription("Get the weather in location")
|
||||
.withResponseConverter((response) -> "" + response.temp() + response.unit())
|
||||
.build()))
|
||||
.build();
|
||||
|
||||
Flux<ChatResponse> response = streamingChatModel.stream(new Prompt(messages, promptOptions));
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
logger.info("Response: {}", content);
|
||||
|
||||
assertThat(content).containsAnyOf("30.0", "30");
|
||||
assertThat(content).containsAnyOf("10.0", "10");
|
||||
assertThat(content).containsAnyOf("15.0", "15");
|
||||
}
|
||||
|
||||
@ParameterizedTest(name = "{0} : {displayName} ")
|
||||
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
|
||||
void multiModalityEmbeddedImage(String modelName) throws IOException {
|
||||
|
||||
var imageData = new ClassPathResource("/test.png");
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?",
|
||||
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
|
||||
|
||||
var response = chatModel
|
||||
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
|
||||
}
|
||||
|
||||
@ParameterizedTest(name = "{0} : {displayName} ")
|
||||
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
|
||||
void multiModalityImageUrl(String modelName) throws IOException {
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?", List
|
||||
.of(new Media(MimeTypeUtils.IMAGE_PNG,
|
||||
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
|
||||
|
||||
ChatResponse response = chatModel
|
||||
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
|
||||
|
||||
logger.info(response.getResult().getOutput().getContent());
|
||||
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
|
||||
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
|
||||
}
|
||||
|
||||
@Test
|
||||
void streamingMultiModalityImageUrl() throws IOException {
|
||||
|
||||
var userMessage = new UserMessage("Explain what do you see on this picture?", List
|
||||
.of(new Media(MimeTypeUtils.IMAGE_PNG,
|
||||
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
|
||||
|
||||
Flux<ChatResponse> response = streamingChatModel.stream(new Prompt(List.of(userMessage),
|
||||
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
|
||||
|
||||
String content = response.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
.map(ChatResponse::getResults)
|
||||
.flatMap(List::stream)
|
||||
.map(Generation::getOutput)
|
||||
.map(AssistantMessage::getContent)
|
||||
.collect(Collectors.joining());
|
||||
logger.info("Response: {}", content);
|
||||
assertThat(content).contains("bananas", "apple");
|
||||
assertThat(content).containsAnyOf("bowl", "basket");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -39,10 +39,10 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest(classes = OpenAiTestConfiguration.class)
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
|
||||
class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
|
||||
class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT {
|
||||
|
||||
private static final Logger logger = LoggerFactory
|
||||
.getLogger(OpenAiChatClientTypeReferenceBeanOutputConverterIT.class);
|
||||
.getLogger(OpenAiChatModelTypeReferenceBeanOutputConverterIT.class);
|
||||
|
||||
record ActorsFilmsRecord(String actor, List<String> movies) {
|
||||
}
|
||||
@@ -61,7 +61,7 @@ class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
|
||||
""";
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
Generation generation = chatClient.call(prompt).getResult();
|
||||
Generation generation = chatModel.call(prompt).getResult();
|
||||
|
||||
List<ActorsFilmsRecord> actorsFilms = outputConverter.convert(generation.getOutput().getContent());
|
||||
logger.info("" + actorsFilms);
|
||||
@@ -87,7 +87,7 @@ class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
|
||||
Prompt prompt = new Prompt(promptTemplate.createMessage());
|
||||
|
||||
String generationTextFromStream = streamingChatClient.stream(prompt)
|
||||
String generationTextFromStream = streamingChatModel.stream(prompt)
|
||||
.collectList()
|
||||
.block()
|
||||
.stream()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user