ChatClient Fluent DSL and renaming [Model/Chat/Embedding/Image/Speech]Client to Model, [Chat/Embedding/Image/Speech]Model

- copy constructor for ChatClient#call incantations based on defaults.
 - copy constructor for ChatClient#call incantations based on defaults.
 - fix ChatClient structured output flow.
 - add OpenAiChatClientIT.
 - add ChatClient stream support
 - implement fromOptions copy factory for every ChatOptions implementation.
 - extend ChatClient to use the caller default options if not provided explicitly.
 - fix system/user text overriding default system/user texts. Only non-empty
   user/system text can override the default system/user text.
 - rename chat() method to collect().
 - add OpenAI FunctionCallbackWrapper2IT auto-config tests.
 - rename request call() to prompt() and the response collect() to call(). Adjust tests
 - add overload methods for defaultSystem()/defaultUser() and system() methods.
 - add ChatClientTest mockito testing
 - fix ChatClient list() conversions
 - Rename the ModelClient class hierarchy into Model
  - Rename ModelClient into Model. Update all code and doc references.
  - Rename ChatClient to ChatModel. Update all ChatClient suffixes and chatClient fields and variables in code and doc.
  - Rename EmbeddingClient into EmbeddingModel. Update the XxxEmbeddingClient class and variable suffixes and embeddingClient variables and fields in code and docs.
  - Rename ImageClient into ImageModel. ....
  - Rename SpeechClient into SpeechModel ....
  - Rename TranscriptionClient into TranscriptionModel ...
  - Update all javadocs and antora pages. Update the related diagrams.
 - refactor usability tests.

Co-authored-by:  Christian Tzolov <ctzolov@vmware.com>
This commit is contained in:
Josh Long
2024-05-18 21:30:51 +02:00
committed by Christian Tzolov
parent bce45c2d2f
commit 57615b6303
419 changed files with 4591 additions and 2925 deletions

View File

@@ -26,6 +26,7 @@ import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatModel;
import reactor.core.publisher.Flux;
import org.springframework.ai.anthropic.api.AnthropicApi;
@@ -38,10 +39,9 @@ import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.AnthropicApi.StreamResponse;
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -56,16 +56,16 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* The {@link ChatClient} implementation for the Anthropic service.
* The {@link ChatModel} implementation for the Anthropic service.
*
* @author Christian Tzolov
* @since 1.0.0
*/
public class AnthropicChatClient extends
public class AnthropicChatModel extends
AbstractFunctionCallSupport<AnthropicApi.RequestMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletion>>
implements ChatClient, StreamingChatClient {
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClient.class);
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue();
@@ -89,10 +89,10 @@ public class AnthropicChatClient extends
public final RetryTemplate retryTemplate;
/**
* Construct a new {@link AnthropicChatClient} instance.
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
*/
public AnthropicChatClient(AnthropicApi anthropicApi) {
public AnthropicChatModel(AnthropicApi anthropicApi) {
this(anthropicApi,
AnthropicChatOptions.builder()
.withModel(DEFAULT_MODEL_NAME)
@@ -102,34 +102,34 @@ public class AnthropicChatClient extends
}
/**
* Construct a new {@link AnthropicChatClient} instance.
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
*/
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Construct a new {@link AnthropicChatClient} instance.
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
*/
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate) {
this(anthropicApi, defaultOptions, retryTemplate, null);
}
/**
* Construct a new {@link AnthropicChatClient} instance.
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackContext the function callback context used to store the
* state of the function calls.
*/
public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
@@ -457,4 +457,9 @@ public class AnthropicChatClient extends
"Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version.");
}
/**
* Returns the result of {@code AnthropicChatOptions.fromOptions(...)} applied to this
* model's {@code defaultOptions}, rather than the {@code defaultOptions} field itself.
* NOTE(review): presumably this keeps callers from mutating the model's own defaults
* through the returned object - confirm how deep the copy made by
* {@code fromOptions} is.
* @return the options produced from the current default options.
*/
@Override
public ChatOptions getDefaultOptions() {
return AnthropicChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -51,11 +51,11 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
private @JsonProperty("top_k") Integer topK;
/**
* Tool Function Callbacks to register with the ChatClient. For Prompt
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatClient chat completion requests.
* to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
@@ -223,4 +223,17 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
this.functions = functions;
}
/**
* Copy factory: builds a new {@link AnthropicChatOptions} through {@code builder()},
* populated from the given instance's model, max tokens, metadata, stop sequences,
* temperature, top-p, top-k, function callbacks and functions.
* <p>
* Each getter result is handed to the builder as-is. NOTE(review): unless the builder
* copies them, collection-valued properties (stop sequences, function callbacks,
* functions) are shared with the source instance - confirm against the builder.
* @param fromOptions the options to read the values from. It is dereferenced without
* a null check.
* @return the options instance produced by the builder.
*/
public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
return builder().withModel(fromOptions.getModel())
.withMaxTokens(fromOptions.getMaxTokens())
.withMetadata(fromOptions.getMetadata())
.withStopSequences(fromOptions.getStopSequences())
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withTopK(fromOptions.getTopK())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.build();
}
}

View File

@@ -26,6 +26,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
@@ -116,7 +117,7 @@ public class AnthropicApi {
* "https://docs.anthropic.com/claude/docs/models-overview#model-comparison">model
* comparison</a> for additional details and options.
*/
public enum ChatModel {
public enum ChatModel implements ModelDescription {
// @formatter:off
CLAUDE_3_OPUS("claude-3-opus-20240229"),
@@ -140,6 +141,11 @@ public class AnthropicApi {
return this.value;
}
/**
* Returns this constant's {@code value} string as the model name, for example
* {@code "claude-3-opus-20240229"} for {@code CLAUDE_3_OPUS}.
* @return the model name string held by this constant.
*/
@Override
public String getModelName() {
return this.value;
}
}
/**

View File

@@ -29,10 +29,10 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.tool.MockWeatherService;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
@@ -56,15 +56,15 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = AnthropicTestConfiguration.class, properties = "spring.ai.retry.on-http-codes=429")
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
class AnthropicChatClientIT {
class AnthropicChatModelIT {
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatClientIT.class);
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModelIT.class);
@Autowired
protected ChatClient chatClient;
protected ChatModel chatModel;
@Autowired
protected StreamingChatClient streamingChatClient;
protected StreamingChatModel streamingChatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -76,7 +76,7 @@ class AnthropicChatClientIT {
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getMetadata().getUsage().getGenerationTokens()).isGreaterThan(0);
assertThat(response.getMetadata().getUsage().getPromptTokens()).isGreaterThan(0);
@@ -102,7 +102,7 @@ class AnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.chatClient.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = listOutputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -120,7 +120,7 @@ class AnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = mapOutputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -142,7 +142,7 @@ class AnthropicChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = beanOutputConverter.convert(generation.getOutput().getContent());
logger.info("" + actorsFilms);
@@ -163,7 +163,7 @@ class AnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = streamingChatClient.stream(prompt)
String generationTextFromStream = streamingChatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -187,7 +187,7 @@ class AnthropicChatClientIT {
var userMessage = new UserMessage("Explain what do you see on this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
var response = chatClient.call(new Prompt(List.of(userMessage)));
var response = chatModel.call(new Prompt(List.of(userMessage)));
logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
@@ -209,7 +209,7 @@ class AnthropicChatClientIT {
.build()))
.build();
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
logger.info("Response: {}", response);

View File

@@ -38,9 +38,9 @@ public class AnthropicTestConfiguration {
}
@Bean
public AnthropicChatClient openAiChatClient(AnthropicApi api) {
AnthropicChatClient anthropicChatClient = new AnthropicChatClient(api);
return anthropicChatClient;
public AnthropicChatModel openAiChatModel(AnthropicApi api) {
AnthropicChatModel anthropicChatModel = new AnthropicChatModel(api);
return anthropicChatModel;
}
}

View File

@@ -30,7 +30,7 @@ public class ChatCompletionRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new AnthropicChatClient(new AnthropicApi("TEST"),
var client = new AnthropicChatModel(new AnthropicApi("TEST"),
AnthropicChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
var request = client.createRequest(new Prompt("Test message content"), false);

View File

@@ -38,10 +38,10 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
@@ -63,7 +63,7 @@ import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
* {@link ChatModel} implementation for {@literal Microsoft Azure AI} backed by
* {@link OpenAIClient}.
*
* @author Mark Pollack
@@ -71,12 +71,12 @@ import java.util.concurrent.atomic.AtomicBoolean;
* @author John Blum
* @author Christian Tzolov
* @author Grogdunn
* @see ChatClient
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
*/
public class AzureOpenAiChatClient
public class AzureOpenAiChatModel
extends AbstractFunctionCallSupport<ChatRequestMessage, ChatCompletionsOptions, ChatCompletions>
implements ChatClient, StreamingChatClient {
implements ChatModel, StreamingChatModel {
private static final String DEFAULT_DEPLOYMENT_NAME = "gpt-35-turbo";
@@ -94,7 +94,7 @@ public class AzureOpenAiChatClient
*/
private final OpenAIClient openAIClient;
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient) {
this(microsoftOpenAiClient,
AzureOpenAiChatOptions.builder()
.withDeploymentName(DEFAULT_DEPLOYMENT_NAME)
@@ -102,11 +102,11 @@ public class AzureOpenAiChatClient
.build());
}
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options) {
this(microsoftOpenAiClient, options, null);
}
public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
public AzureOpenAiChatModel(OpenAIClient microsoftOpenAiClient, AzureOpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
Assert.notNull(microsoftOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
@@ -117,17 +117,17 @@ public class AzureOpenAiChatClient
/**
* @deprecated since 0.8.0, use
* {@link #AzureOpenAiChatClient(OpenAIClient, AzureOpenAiChatOptions)} instead.
* {@link #AzureOpenAiChatModel(OpenAIClient, AzureOpenAiChatOptions)} instead.
*/
@Deprecated(forRemoval = true, since = "0.8.0")
public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
public AzureOpenAiChatModel withDefaultOptions(AzureOpenAiChatOptions defaultOptions) {
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.defaultOptions = defaultOptions;
return this;
}
public AzureOpenAiChatOptions getDefaultOptions() {
return this.defaultOptions;
return AzureOpenAiChatOptions.fromOptions(this.defaultOptions);
}
@Override

View File

@@ -127,11 +127,11 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
private String deploymentName;
/**
* OpenAI Tool Function Callbacks to register with the ChatClient. For Prompt Options
* OpenAI Tool Function Callbacks to register with the ChatModel. For Prompt Options
* the functionCallbacks are automatically enabled for the duration of the prompt
* execution. For Default Options the functionCallbacks are registered but disabled by
* default. Use the enableFunctions to set the functions from the registry to be used
* by the ChatClient chat completion requests.
* by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
@@ -356,4 +356,22 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
this.functions = functions;
}
/**
* Copy factory: builds a new {@link AzureOpenAiChatOptions} through {@code builder()},
* populated from the given instance's deployment name, frequency penalty, logit bias,
* max tokens, n, presence penalty, stop, temperature, top-p, user, function callbacks
* and functions.
* <p>
* The frequency and presence penalties are passed as {@code floatValue()} when
* non-null and as {@code null} otherwise. All other getter results are handed to the
* builder as-is. NOTE(review): unless the builder copies them, collection-valued
* properties are shared with the source instance - confirm against the builder.
* @param fromOptions the options to read the values from. It is dereferenced without
* a null check.
* @return the options instance produced by the builder.
*/
public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) {
return builder().withDeploymentName(fromOptions.getDeploymentName())
.withFrequencyPenalty(
fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty().floatValue() : null)
.withLogitBias(fromOptions.getLogitBias())
.withMaxTokens(fromOptions.getMaxTokens())
.withN(fromOptions.getN())
.withPresencePenalty(
fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty().floatValue() : null)
.withStop(fromOptions.getStop())
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withUser(fromOptions.getUser())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.build();
}
}

View File

@@ -28,7 +28,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -37,9 +37,9 @@ import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingClient.class);
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);
private final OpenAIClient azureOpenAiClient;
@@ -47,16 +47,16 @@ public class AzureOpenAiEmbeddingClient extends AbstractEmbeddingClient {
private final MetadataMode metadataMode;
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient) {
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) {
this(azureOpenAiClient, MetadataMode.EMBED);
}
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode) {
this(azureOpenAiClient, metadataMode,
AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build());
}
public AzureOpenAiEmbeddingClient(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
AzureOpenAiEmbeddingOptions options) {
Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");

View File

@@ -53,7 +53,7 @@ public class AzureChatCompletionsOptionsTests {
.withUser("user")
.build();
var client = new AzureOpenAiChatClient(mockClient, defaultOptions);
var client = new AzureOpenAiChatModel(mockClient, defaultOptions);
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));

View File

@@ -36,7 +36,7 @@ public class AzureEmbeddingsOptionsTests {
public void createRequestWithChatOptions() {
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
var client = new AzureOpenAiEmbeddingClient(mockClient, MetadataMode.EMBED,
var client = new AzureOpenAiEmbeddingModel(mockClient, MetadataMode.EMBED,
AzureOpenAiEmbeddingOptions.builder()
.withDeploymentName("DEFAULT_MODEL")
.withUser("USER_TEST")

View File

@@ -46,13 +46,13 @@ import org.springframework.core.convert.support.DefaultConversionService;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = AzureOpenAiChatClientIT.TestConfiguration.class)
@SpringBootTest(classes = AzureOpenAiChatModelIT.TestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
class AzureOpenAiChatClientIT {
class AzureOpenAiChatModelIT {
@Autowired
private AzureOpenAiChatClient chatClient;
private AzureOpenAiChatModel chatModel;
record ActorsFilms(String actor, List<String> movies) {
}
@@ -69,7 +69,7 @@ class AzureOpenAiChatClientIT {
UserMessage userMessage = new UserMessage("Generate the names of 5 famous pirates.");
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -86,7 +86,7 @@ class AzureOpenAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -105,7 +105,7 @@ class AzureOpenAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -124,7 +124,7 @@ class AzureOpenAiChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isNotNull();
@@ -145,7 +145,7 @@ class AzureOpenAiChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
System.out.println(actorsFilms);
@@ -166,7 +166,7 @@ class AzureOpenAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = chatClient.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -194,8 +194,8 @@ class AzureOpenAiChatClientIT {
}
@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
return new AzureOpenAiChatClient(openAIClient,
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient) {
return new AzureOpenAiChatModel(openAIClient,
AzureOpenAiChatOptions.builder().withDeploymentName("gpt-35-turbo").withMaxTokens(200).build());
}

View File

@@ -34,25 +34,25 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
class AzureOpenAiEmbeddingClientIT {
class AzureOpenAiEmbeddingModelIT {
@Autowired
private AzureOpenAiEmbeddingClient embeddingClient;
private AzureOpenAiEmbeddingModel embeddingModel;
@Test
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
System.out.println(embeddingClient.dimensions());
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
System.out.println(embeddingModel.dimensions());
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
}
@Test
void batchEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
@@ -60,7 +60,7 @@ class AzureOpenAiEmbeddingClientIT {
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingClient.dimensions()).isEqualTo(1536);
assertThat(embeddingModel.dimensions()).isEqualTo(1536);
}
@SpringBootConfiguration
@@ -74,8 +74,8 @@ class AzureOpenAiEmbeddingClientIT {
}
@Bean
public AzureOpenAiEmbeddingClient azureEmbeddingClient(OpenAIClient openAIClient) {
return new AzureOpenAiEmbeddingClient(openAIClient);
public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient) {
return new AzureOpenAiEmbeddingModel(openAIClient);
}
}

View File

@@ -59,8 +59,8 @@ public class MockAzureOpenAiTestConfiguration {
}
@Bean
AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient microsoftAzureOpenAiClient) {
return new AzureOpenAiChatClient(microsoftAzureOpenAiClient);
AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient microsoftAzureOpenAiClient) {
return new AzureOpenAiChatModel(microsoftAzureOpenAiClient);
}
}

View File

@@ -29,7 +29,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
@@ -46,18 +46,18 @@ import reactor.core.publisher.Flux;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = AzureOpenAiChatClientFunctionCallIT.TestConfiguration.class)
@SpringBootTest(classes = AzureOpenAiChatModelFunctionCallIT.TestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
class AzureOpenAiChatClientFunctionCallIT {
class AzureOpenAiChatModelFunctionCallIT {
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatClientFunctionCallIT.class);
private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModelFunctionCallIT.class);
@Autowired
private String selectedModel;
@Autowired
private AzureOpenAiChatClient chatClient;
private AzureOpenAiChatModel chatModel;
@Test
void functionCallTest() {
@@ -75,7 +75,7 @@ class AzureOpenAiChatClientFunctionCallIT {
.build()))
.build();
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
logger.info("Response: {}", response);
@@ -99,7 +99,7 @@ class AzureOpenAiChatClientFunctionCallIT {
.build()))
.build();
Flux<ChatResponse> response = chatClient.stream(new Prompt(messages, promptOptions));
Flux<ChatResponse> response = chatModel.stream(new Prompt(messages, promptOptions));
final var counter = new AtomicInteger();
String content = response.doOnEach(listSignal -> counter.getAndIncrement())
@@ -129,8 +129,8 @@ class AzureOpenAiChatClientFunctionCallIT {
}
@Bean
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, String selectedModel) {
return new AzureOpenAiChatClient(openAIClient,
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClient openAIClient, String selectedModel) {
return new AzureOpenAiChatModel(openAIClient,
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
}

View File

@@ -23,7 +23,7 @@ import com.azure.ai.openai.models.ContentFilterResultsForChoice;
import com.azure.ai.openai.models.ContentFilterSeverity;
import org.junit.jupiter.api.Test;
import org.springframework.ai.azure.openai.AzureOpenAiChatClient;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.MockAzureOpenAiTestConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
@@ -55,7 +55,7 @@ import org.springframework.web.context.request.WebRequest;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Unit Tests for {@link AzureOpenAiChatClient} asserting AI metadata.
* Unit Tests for {@link AzureOpenAiChatModel} asserting AI metadata.
*
* @author John Blum
* @author Christian Tzolov
@@ -63,12 +63,12 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
@SpringBootTest
@ActiveProfiles("spring-ai-azure-openai-mocks")
@ContextConfiguration(classes = AzureOpenAiChatClientMetadataTests.TestConfiguration.class)
@ContextConfiguration(classes = AzureOpenAiChatModelMetadataTests.TestConfiguration.class)
@SuppressWarnings("unused")
class AzureOpenAiChatClientMetadataTests {
class AzureOpenAiChatModelMetadataTests {
@Autowired
private AzureOpenAiChatClient aiClient;
private AzureOpenAiChatModel aiClient;
@Test
void azureOpenAiMetadataCapturedDuringGeneration() {

View File

@@ -164,4 +164,14 @@ public class AnthropicChatOptions implements ChatOptions {
this.anthropicVersion = anthropicVersion;
}
/**
 * Builds an {@link AnthropicChatOptions} through {@code builder()} using the values
 * read from the given instance: temperature, max tokens to sample, top-K, top-P,
 * stop sequences and Anthropic version.
 * <p>
 * NOTE(review): the stop-sequences value is passed to the builder exactly as the
 * getter returns it; whether the builder copies it is not visible here, so the
 * result may share that object with the source - confirm.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokensToSample(fromOptions.getMaxTokensToSample())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withAnthropicVersion(fromOptions.getAnthropicVersion())
.build();
}
}

View File

@@ -17,7 +17,7 @@ package org.springframework.ai.bedrock.anthropic;
import java.util.List;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -27,25 +27,25 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
/**
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat
* generative model.
*
* @author Christian Tzolov
* @since 0.8.0
*/
public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClient {
public class BedrockAnthropicChatModel implements ChatModel, StreamingChatModel {
private final AnthropicChatBedrockApi anthropicChatApi;
private final AnthropicChatOptions defaultOptions;
public BedrockAnthropicChatClient(AnthropicChatBedrockApi chatApi) {
public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi) {
this(chatApi,
AnthropicChatOptions.builder()
.withTemperature(0.8f)
@@ -55,7 +55,7 @@ public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClie
.build());
}
public BedrockAnthropicChatClient(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) {
public BedrockAnthropicChatModel(AnthropicChatBedrockApi chatApi, AnthropicChatOptions options) {
this.anthropicChatApi = chatApi;
this.defaultOptions = options;
}
@@ -117,4 +117,9 @@ public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClie
return request;
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code AnthropicChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return AnthropicChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -29,6 +29,7 @@ import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.model.ModelDescription;
import org.springframework.util.Assert;
/**
@@ -225,7 +226,7 @@ public class AnthropicChatBedrockApi extends
/**
* Anthropic models version.
*/
public enum AnthropicChatModel {
public enum AnthropicChatModel implements ModelDescription {
/**
* anthropic.claude-instant-v1
*/
@@ -251,6 +252,11 @@ public class AnthropicChatBedrockApi extends
AnthropicChatModel(String value) {
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -163,4 +163,14 @@ public class Anthropic3ChatOptions implements ChatOptions {
this.anthropicVersion = anthropicVersion;
}
/**
 * Builds an {@link Anthropic3ChatOptions} through {@code builder()} using the values
 * read from the given instance: temperature, max tokens, top-K, top-P, stop
 * sequences and Anthropic version.
 * <p>
 * NOTE(review): the stop-sequences value is passed to the builder exactly as the
 * getter returns it; whether the builder copies it is not visible here - confirm.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static Anthropic3ChatOptions fromOptions(Anthropic3ChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withMaxTokens(fromOptions.getMaxTokens())
.withTopK(fromOptions.getTopK())
.withTopP(fromOptions.getTopP())
.withStopSequences(fromOptions.getStopSequences())
.withAnthropicVersion(fromOptions.getAnthropicVersion())
.build();
}
}

View File

@@ -15,17 +15,25 @@
*/
package org.springframework.ai.bedrock.anthropic3;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import reactor.core.publisher.Flux;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -34,29 +42,21 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
/**
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Anthropic chat
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Anthropic chat
* generative model.
*
* @author Ben Middleton
* @author Christian Tzolov
* @since 1.0.0
*/
public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient {
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
private final Anthropic3ChatBedrockApi anthropicChatApi;
private final Anthropic3ChatOptions defaultOptions;
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi) {
public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi) {
this(chatApi,
Anthropic3ChatOptions.builder()
.withTemperature(0.8f)
@@ -66,7 +66,7 @@ public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatCli
.build());
}
public BedrockAnthropic3ChatClient(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi chatApi, Anthropic3ChatOptions options) {
this.anthropicChatApi = chatApi;
this.defaultOptions = options;
}
@@ -187,4 +187,9 @@ public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatCli
}
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code Anthropic3ChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -23,6 +23,7 @@ import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.An
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.model.ModelDescription;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
@@ -436,7 +437,7 @@ public class Anthropic3ChatBedrockApi extends
/**
* Anthropic models version.
*/
public enum AnthropicChatModel {
public enum AnthropicChatModel implements ModelDescription {
/**
* anthropic.claude-instant-v1
@@ -476,6 +477,11 @@ public class Anthropic3ChatBedrockApi extends
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -24,11 +24,11 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -39,17 +39,17 @@ import org.springframework.util.Assert;
* @author Christian Tzolov
* @since 0.8.0
*/
public class BedrockCohereChatClient implements ChatClient, StreamingChatClient {
public class BedrockCohereChatModel implements ChatModel, StreamingChatModel {
private final CohereChatBedrockApi chatApi;
private final BedrockCohereChatOptions defaultOptions;
public BedrockCohereChatClient(CohereChatBedrockApi chatApi) {
public BedrockCohereChatModel(CohereChatBedrockApi chatApi) {
this(chatApi, BedrockCohereChatOptions.builder().build());
}
public BedrockCohereChatClient(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
public BedrockCohereChatModel(CohereChatBedrockApi chatApi, BedrockCohereChatOptions options) {
Assert.notNull(chatApi, "CohereChatBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereChatOptions must not be null");
@@ -114,4 +114,9 @@ public class BedrockCohereChatClient implements ChatClient, StreamingChatClient
return request;
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code BedrockCohereChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return BedrockCohereChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -213,4 +213,17 @@ public class BedrockCohereChatOptions implements ChatOptions {
this.truncate = truncate;
}
/**
 * Builds a {@link BedrockCohereChatOptions} through {@code builder()} using the
 * values read from the given instance: temperature, top-P, top-K, max tokens, stop
 * sequences, return-likelihoods, number of generations, logit bias and truncate.
 * <p>
 * NOTE(review): the stop-sequences and logit-bias values are passed to the builder
 * exactly as the getters return them; whether the builder copies them is not visible
 * here - confirm.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static BedrockCohereChatOptions fromOptions(BedrockCohereChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withTopK(fromOptions.getTopK())
.withMaxTokens(fromOptions.getMaxTokens())
.withStopSequences(fromOptions.getStopSequences())
.withReturnLikelihoods(fromOptions.getReturnLikelihoods())
.withNumGenerations(fromOptions.getNumGenerations())
.withLogitBias(fromOptions.getLogitBias())
.withTruncate(fromOptions.getTruncate())
.build();
}
}

View File

@@ -22,7 +22,7 @@ import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingRequest;
import org.springframework.ai.bedrock.cohere.api.CohereEmbeddingBedrockApi.CohereEmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -31,14 +31,14 @@ import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
/**
* {@link org.springframework.ai.embedding.EmbeddingClient} implementation that uses the
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
* Bedrock Cohere Embedding API. Note: The invocation metrics are not exposed by AWS for
* this API. If this changes in the future, we will add it as metadata.
*
* @author Christian Tzolov
* @since 0.8.0
*/
public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
public class BedrockCohereEmbeddingModel extends AbstractEmbeddingModel {
private final CohereEmbeddingBedrockApi embeddingApi;
@@ -50,7 +50,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
// private CohereEmbeddingRequest.Truncate truncate =
// CohereEmbeddingRequest.Truncate.NONE;
public BedrockCohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi) {
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi) {
this(cohereEmbeddingBedrockApi,
BedrockCohereEmbeddingOptions.builder()
.withInputType(CohereEmbeddingRequest.InputType.SEARCH_DOCUMENT)
@@ -58,7 +58,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
.build());
}
public BedrockCohereEmbeddingClient(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedrockApi,
BedrockCohereEmbeddingOptions options) {
Assert.notNull(cohereEmbeddingBedrockApi, "CohereEmbeddingBedrockApi must not be null");
Assert.notNull(options, "BedrockCohereEmbeddingOptions must not be null");
@@ -71,7 +71,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
// * @param inputType the input type to use.
// * @return this client.
// */
// public BedrockCohereEmbeddingClient withInputType(CohereEmbeddingRequest.InputType
// public BedrockCohereEmbeddingModel withInputType(CohereEmbeddingRequest.InputType
// inputType) {
// this.inputType = inputType;
// return this;
@@ -85,7 +85,7 @@ public class BedrockCohereEmbeddingClient extends AbstractEmbeddingClient {
// * @param truncate the truncate option to use.
// * @return this client.
// */
// public BedrockCohereEmbeddingClient withTruncate(CohereEmbeddingRequest.Truncate
// public BedrockCohereEmbeddingModel withTruncate(CohereEmbeddingRequest.Truncate
// truncate) {
// this.truncate = truncate;
// return this;

View File

@@ -30,6 +30,7 @@ import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatRequest;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi.CohereChatResponse;
import org.springframework.ai.model.ModelDescription;
import org.springframework.util.Assert;
/**
@@ -366,7 +367,7 @@ public class CohereChatBedrockApi extends
/**
* Cohere models version.
*/
public enum CohereChatModel {
public enum CohereChatModel implements ModelDescription {
/**
* cohere.command-light-text-v14
@@ -390,6 +391,11 @@ public class CohereChatBedrockApi extends
CohereChatModel(String value) {
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -19,7 +19,7 @@ package org.springframework.ai.bedrock.jurassic2;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -29,19 +29,18 @@ import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
/**
* Java {@link ChatClient} for the Bedrock Jurassic2 chat generative model.
* Java {@link ChatModel} for the Bedrock Jurassic2 chat generative model.
*
* @author Ahmed Yousri
* @since 1.0.0
*/
public class BedrockAi21Jurassic2ChatClient implements ChatClient {
public class BedrockAi21Jurassic2ChatModel implements ChatModel {
private final Ai21Jurassic2ChatBedrockApi chatApi;
private final BedrockAi21Jurassic2ChatOptions defaultOptions;
public BedrockAi21Jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi chatApi,
BedrockAi21Jurassic2ChatOptions options) {
public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi, BedrockAi21Jurassic2ChatOptions options) {
Assert.notNull(chatApi, "Ai21Jurassic2ChatBedrockApi must not be null");
Assert.notNull(options, "BedrockAi21Jurassic2ChatOptions must not be null");
@@ -49,7 +48,7 @@ public class BedrockAi21Jurassic2ChatClient implements ChatClient {
this.defaultOptions = options;
}
public BedrockAi21Jurassic2ChatClient(Ai21Jurassic2ChatBedrockApi chatApi) {
public BedrockAi21Jurassic2ChatModel(Ai21Jurassic2ChatBedrockApi chatApi) {
this(chatApi,
BedrockAi21Jurassic2ChatOptions.builder()
.withTemperature(0.8f)
@@ -114,11 +113,16 @@ public class BedrockAi21Jurassic2ChatClient implements ChatClient {
return this;
}
public BedrockAi21Jurassic2ChatClient build() {
return new BedrockAi21Jurassic2ChatClient(chatApi,
public BedrockAi21Jurassic2ChatModel build() {
return new BedrockAi21Jurassic2ChatModel(chatApi,
options != null ? options : BedrockAi21Jurassic2ChatOptions.builder().build());
}
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code BedrockAi21Jurassic2ChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return BedrockAi21Jurassic2ChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -413,4 +413,19 @@ public class BedrockAi21Jurassic2ChatOptions implements ChatOptions {
}
}
/**
 * Builds a {@link BedrockAi21Jurassic2ChatOptions} through {@code builder()} using
 * the values read from the given instance: prompt, number of results, max tokens,
 * min tokens, temperature, top-P, top-K, stop sequences, and the frequency, presence
 * and count penalties.
 * <p>
 * NOTE(review): the stop-sequences and penalty values are passed to the builder
 * exactly as the getters return them; whether they are copied is not visible here -
 * confirm.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static BedrockAi21Jurassic2ChatOptions fromOptions(BedrockAi21Jurassic2ChatOptions fromOptions) {
return builder().withPrompt(fromOptions.getPrompt())
.withNumResults(fromOptions.getNumResults())
.withMaxTokens(fromOptions.getMaxTokens())
.withMinTokens(fromOptions.getMinTokens())
.withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withTopK(fromOptions.getTopK())
.withStopSequences(fromOptions.getStopSequences())
.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
.withPresencePenalty(fromOptions.getPresencePenalty())
.withCountPenalty(fromOptions.getCountPenalty())
.build();
}
}

View File

@@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatRequest;
import org.springframework.ai.bedrock.jurassic2.api.Ai21Jurassic2ChatBedrockApi.Ai21Jurassic2ChatResponse;
import org.springframework.ai.model.ModelDescription;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
@@ -371,7 +372,7 @@ public class Ai21Jurassic2ChatBedrockApi extends
/**
* Ai21 Jurassic2 models version.
*/
public enum Ai21Jurassic2ChatModel {
public enum Ai21Jurassic2ChatModel implements ModelDescription {
/**
* ai21.j2-mid-v1
@@ -395,6 +396,11 @@ public class Ai21Jurassic2ChatBedrockApi extends
Ai21Jurassic2ChatModel(String value) {
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -23,11 +23,11 @@ import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -35,25 +35,25 @@ import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
/**
* Java {@link ChatClient} and {@link StreamingChatClient} for the Bedrock Llama chat
* Java {@link ChatModel} and {@link StreamingChatModel} for the Bedrock Llama chat
* generative model.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
public class BedrockLlamaChatModel implements ChatModel, StreamingChatModel {
private final LlamaChatBedrockApi chatApi;
private final BedrockLlamaChatOptions defaultOptions;
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi) {
public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi) {
this(chatApi,
BedrockLlamaChatOptions.builder().withTemperature(0.8f).withTopP(0.9f).withMaxGenLen(100).build());
}
public BedrockLlamaChatClient(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
public BedrockLlamaChatModel(LlamaChatBedrockApi chatApi, BedrockLlamaChatOptions options) {
Assert.notNull(chatApi, "LlamaChatBedrockApi must not be null");
Assert.notNull(options, "BedrockLlamaChatOptions must not be null");
@@ -130,4 +130,9 @@ public class BedrockLlamaChatClient implements ChatClient, StreamingChatClient {
return request;
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code BedrockLlamaChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return BedrockLlamaChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -109,4 +109,11 @@ public class BedrockLlamaChatOptions implements ChatOptions {
throw new UnsupportedOperationException("Unsupported option: 'TopK'");
}
/**
 * Builds a {@link BedrockLlamaChatOptions} through {@code builder()} using the
 * temperature, top-P and max generation length read from the given instance. No
 * other values are transferred.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static BedrockLlamaChatOptions fromOptions(BedrockLlamaChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withMaxGenLen(fromOptions.getMaxGenLen())
.build();
}
}

View File

@@ -26,6 +26,7 @@ import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.api.AbstractBedrockApi;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatRequest;
import org.springframework.ai.bedrock.llama.api.LlamaChatBedrockApi.LlamaChatResponse;
import org.springframework.ai.model.ModelDescription;
import java.time.Duration;
@@ -204,7 +205,7 @@ public class LlamaChatBedrockApi extends
/**
* Llama models version.
*/
public enum LlamaChatModel {
public enum LlamaChatModel implements ModelDescription {
/**
* meta.llama2-13b-chat-v1
@@ -238,6 +239,11 @@ public class LlamaChatBedrockApi extends
LlamaChatModel(String value) {
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -24,11 +24,11 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatRequest;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -39,17 +39,17 @@ import org.springframework.util.Assert;
* @author Christian Tzolov
* @since 0.8.0
*/
public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
private final TitanChatBedrockApi chatApi;
private final BedrockTitanChatOptions defaultOptions;
public BedrockTitanChatClient(TitanChatBedrockApi chatApi) {
public BedrockTitanChatModel(TitanChatBedrockApi chatApi) {
this(chatApi, BedrockTitanChatOptions.builder().withTemperature(0.8f).build());
}
public BedrockTitanChatClient(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
public BedrockTitanChatModel(TitanChatBedrockApi chatApi, BedrockTitanChatOptions defaultOptions) {
Assert.notNull(chatApi, "ChatApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.chatApi = chatApi;
@@ -146,4 +146,9 @@ public class BedrockTitanChatClient implements ChatClient, StreamingChatClient {
};
}
/**
 * Returns the default options of this model. The {@code defaultOptions} field is
 * not returned directly; it is passed through
 * {@code BedrockTitanChatOptions.fromOptions(..)} first.
 * @return the options built from this model's {@code defaultOptions}
 */
@Override
public ChatOptions getDefaultOptions() {
return BedrockTitanChatOptions.fromOptions(this.defaultOptions);
}
}

View File

@@ -128,4 +128,12 @@ public class BedrockTitanChatOptions implements ChatOptions {
throw new UnsupportedOperationException("Bedrock Titan Chat does not support the 'TopK' option.'");
}
/**
 * Builds a {@link BedrockTitanChatOptions} through {@code builder()} using the
 * temperature, top-P, max token count and stop sequences read from the given
 * instance.
 * <p>
 * NOTE(review): the stop-sequences value is passed to the builder exactly as the
 * getter returns it; whether the builder copies it is not visible here - confirm.
 * @param fromOptions the options to read from; dereferenced without a null check
 * @return the options instance produced by the builder
 */
public static BedrockTitanChatOptions fromOptions(BedrockTitanChatOptions fromOptions) {
return builder().withTemperature(fromOptions.getTemperature())
.withTopP(fromOptions.getTopP())
.withMaxTokenCount(fromOptions.getMaxTokenCount())
.withStopSequences(fromOptions.getStopSequences())
.build();
}
}

View File

@@ -26,7 +26,7 @@ import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -34,7 +34,7 @@ import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;
/**
* {@link org.springframework.ai.embedding.EmbeddingClient} implementation that uses the
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in
* base64) inputs.
*
@@ -44,7 +44,7 @@ import org.springframework.util.Assert;
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {
private final Logger logger = LoggerFactory.getLogger(getClass());
@@ -61,7 +61,7 @@ public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
*/
private InputType inputType = InputType.TEXT;
public BedrockTitanEmbeddingClient(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
this.embeddingApi = titanEmbeddingBedrockApi;
}
@@ -69,7 +69,7 @@ public class BedrockTitanEmbeddingClient extends AbstractEmbeddingClient {
* Titan Embedding API input types. Could be either text or image (encoded in base64).
* @param inputType the input type to use.
*/
public BedrockTitanEmbeddingClient withInputType(InputType inputType) {
public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
this.inputType = inputType;
return this;
}

View File

@@ -18,7 +18,7 @@ package org.springframework.ai.bedrock.titan;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.util.Assert;

View File

@@ -31,6 +31,7 @@ import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatReq
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponse.CompletionReason;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi.TitanChatResponseChunk;
import org.springframework.ai.model.ModelDescription;
/**
* Java client for the Bedrock Titan chat model.
@@ -265,7 +266,7 @@ public class TitanChatBedrockApi extends
/**
* Titan models version.
*/
public enum TitanChatModel {
public enum TitanChatModel implements ModelDescription {
/**
* amazon.titan-text-lite-v1
@@ -294,6 +295,11 @@ public class TitanChatBedrockApi extends
TitanChatModel(String value) {
this.id = value;
}
/**
 * Returns the {@code id} value this enum constant was constructed with.
 * @return the value of the {@code id} field
 */
@Override
public String getModelName() {
return this.id;
}
}
@Override

View File

@@ -55,12 +55,12 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockAnthropicChatClientIT {
class BedrockAnthropicChatModelIT {
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropicChatClientIT.class);
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropicChatModelIT.class);
@Autowired
private BedrockAnthropicChatClient client;
private BedrockAnthropicChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -68,8 +68,8 @@ class BedrockAnthropicChatClientIT {
@Test
void multipleStreamAttempts() {
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -101,7 +101,7 @@ class BedrockAnthropicChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -119,7 +119,7 @@ class BedrockAnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputParser.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -137,7 +137,7 @@ class BedrockAnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -161,7 +161,7 @@ class BedrockAnthropicChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConvert.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -182,7 +182,7 @@ class BedrockAnthropicChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -209,8 +209,8 @@ class BedrockAnthropicChatClientIT {
}
@Bean
public BedrockAnthropicChatClient anthropicChatClient(AnthropicChatBedrockApi anthropicApi) {
return new BedrockAnthropicChatClient(anthropicApi);
public BedrockAnthropicChatModel anthropicChatModel(AnthropicChatBedrockApi anthropicApi) {
return new BedrockAnthropicChatModel(anthropicApi);
}
}

View File

@@ -38,7 +38,7 @@ public class BedrockAnthropicCreateRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new BedrockAnthropicChatClient(anthropicChatApi,
var client = new BedrockAnthropicChatModel(anthropicChatApi,
AnthropicChatOptions.builder()
.withTemperature(66.6f)
.withTopK(66)

View File

@@ -59,12 +59,12 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockAnthropic3ChatClientIT {
class BedrockAnthropic3ChatModelIT {
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropic3ChatClientIT.class);
private static final Logger logger = LoggerFactory.getLogger(BedrockAnthropic3ChatModelIT.class);
@Autowired
private BedrockAnthropic3ChatClient client;
private BedrockAnthropic3ChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -72,8 +72,8 @@ class BedrockAnthropic3ChatClientIT {
@Test
void multipleStreamAttempts() {
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -105,7 +105,7 @@ class BedrockAnthropic3ChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -123,7 +123,7 @@ class BedrockAnthropic3ChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -142,7 +142,7 @@ class BedrockAnthropic3ChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -166,7 +166,7 @@ class BedrockAnthropic3ChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -187,7 +187,7 @@ class BedrockAnthropic3ChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -211,7 +211,7 @@ class BedrockAnthropic3ChatClientIT {
var userMessage = new UserMessage("Explain what do you see o this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
var response = client.call(new Prompt(List.of(userMessage)));
var response = chatModel.call(new Prompt(List.of(userMessage)));
logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
@@ -228,8 +228,8 @@ class BedrockAnthropic3ChatClientIT {
}
@Bean
public BedrockAnthropic3ChatClient anthropicChatClient(Anthropic3ChatBedrockApi anthropicApi) {
return new BedrockAnthropic3ChatClient(anthropicApi);
public BedrockAnthropic3ChatModel anthropicChatModel(Anthropic3ChatBedrockApi anthropicApi) {
return new BedrockAnthropic3ChatModel(anthropicApi);
}
}

View File

@@ -37,7 +37,7 @@ public class BedrockAnthropic3CreateRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new BedrockAnthropic3ChatClient(anthropicChatApi,
var client = new BedrockAnthropic3ChatModel(anthropicChatApi,
Anthropic3ChatOptions.builder()
.withTemperature(66.6f)
.withTopK(66)

View File

@@ -45,7 +45,7 @@ public class BedrockCohereChatCreateRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new BedrockCohereChatClient(chatApi,
var client = new BedrockCohereChatModel(chatApi,
BedrockCohereChatOptions.builder()
.withTemperature(66.6f)
.withTopK(66)

View File

@@ -54,10 +54,10 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockCohereChatClientIT {
class BedrockCohereChatModelIT {
@Autowired
private BedrockCohereChatClient client;
private BedrockCohereChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -65,8 +65,8 @@ class BedrockCohereChatClientIT {
@Test
void multipleStreamAttempts() {
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -98,7 +98,7 @@ class BedrockCohereChatClientIT {
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -115,7 +115,7 @@ class BedrockCohereChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -134,7 +134,7 @@ class BedrockCohereChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -157,7 +157,7 @@ class BedrockCohereChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -178,7 +178,7 @@ class BedrockCohereChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -205,8 +205,8 @@ class BedrockCohereChatClientIT {
}
@Bean
public BedrockCohereChatClient cohereChatClient(CohereChatBedrockApi cohereApi) {
return new BedrockCohereChatClient(cohereApi);
public BedrockCohereChatModel cohereChatModel(CohereChatBedrockApi cohereApi) {
return new BedrockCohereChatModel(cohereApi);
}
}

View File

@@ -39,24 +39,24 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockCohereEmbeddingClientIT {
class BedrockCohereEmbeddingModelIT {
@Autowired
private BedrockCohereEmbeddingClient embeddingClient;
private BedrockCohereEmbeddingModel embeddingModel;
@Test
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}
@Test
void batchEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel
.embedForResponse(List.of("Hello World", "World is big and salvation is near"));
assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
@@ -64,13 +64,13 @@ class BedrockCohereEmbeddingClientIT {
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}
@Test
void embeddingWthOptions() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
BedrockCohereEmbeddingOptions.builder().withInputType(InputType.SEARCH_DOCUMENT).build()));
assertThat(embeddingResponse.getResults()).hasSize(2);
@@ -79,7 +79,7 @@ class BedrockCohereEmbeddingClientIT {
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}
@SpringBootConfiguration
@@ -93,8 +93,8 @@ class BedrockCohereEmbeddingClientIT {
}
@Bean
public BedrockCohereEmbeddingClient cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
return new BedrockCohereEmbeddingClient(cohereEmbeddingApi);
public BedrockCohereEmbeddingModel cohereAiEmbedding(CohereEmbeddingBedrockApi cohereEmbeddingApi) {
return new BedrockCohereEmbeddingModel(cohereEmbeddingApi);
}
}

View File

@@ -49,10 +49,10 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockAi21Jurassic2ChatClientIT {
class BedrockAi21Jurassic2ChatModelIT {
@Autowired
private BedrockAi21Jurassic2ChatClient client;
private BedrockAi21Jurassic2ChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -66,7 +66,7 @@ class BedrockAi21Jurassic2ChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -83,7 +83,7 @@ class BedrockAi21Jurassic2ChatClientIT {
UserMessage userMessage = new UserMessage("Can you express happiness using an emoji like 😄 ?");
Prompt prompt = new Prompt(List.of(userMessage), options);
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).matches(content -> content.contains("😄"));
}
@@ -103,7 +103,7 @@ class BedrockAi21Jurassic2ChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), options);
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).doesNotContain("😄");
}
@@ -120,7 +120,7 @@ class BedrockAi21Jurassic2ChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -135,7 +135,7 @@ class BedrockAi21Jurassic2ChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("AI");
}
@@ -152,9 +152,9 @@ class BedrockAi21Jurassic2ChatClientIT {
}
@Bean
public BedrockAi21Jurassic2ChatClient bedrockAi21Jurassic2ChatClient(
public BedrockAi21Jurassic2ChatModel bedrockAi21Jurassic2ChatModel(
Ai21Jurassic2ChatBedrockApi jurassic2ChatBedrockApi) {
return new BedrockAi21Jurassic2ChatClient(jurassic2ChatBedrockApi,
return new BedrockAi21Jurassic2ChatModel(jurassic2ChatBedrockApi,
BedrockAi21Jurassic2ChatOptions.builder()
.withTemperature(0.5f)
.withMaxTokens(100)

View File

@@ -45,7 +45,7 @@ public class BedrockLlamaCreateRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new BedrockLlamaChatClient(api,
var client = new BedrockLlamaChatModel(api,
BedrockLlamaChatOptions.builder().withTemperature(66.6f).withMaxGenLen(666).withTopP(0.66f).build());
var request = client.createRequest(new Prompt("Test message content"));

View File

@@ -54,10 +54,10 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockLlamaChatClientIT {
class BedrockLlamaChatModelIT {
@Autowired
private BedrockLlamaChatClient client;
private BedrockLlamaChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -65,8 +65,8 @@ class BedrockLlamaChatClientIT {
@Test
void multipleStreamAttempts() {
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a Toy joke?")));
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a Toy joke?")));
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -98,7 +98,7 @@ class BedrockLlamaChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -116,7 +116,7 @@ class BedrockLlamaChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -134,7 +134,7 @@ class BedrockLlamaChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -158,7 +158,7 @@ class BedrockLlamaChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -179,7 +179,7 @@ class BedrockLlamaChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -206,8 +206,8 @@ class BedrockLlamaChatClientIT {
}
@Bean
public BedrockLlamaChatClient llamaChatClient(LlamaChatBedrockApi llamaApi) {
return new BedrockLlamaChatClient(llamaApi,
public BedrockLlamaChatModel llamaChatModel(LlamaChatBedrockApi llamaApi) {
return new BedrockLlamaChatModel(llamaApi,
BedrockLlamaChatOptions.builder().withTemperature(0.5f).withMaxGenLen(100).withTopP(0.9f).build());
}

View File

@@ -41,7 +41,7 @@ public class BedrockTitanChatCreateRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new BedrockTitanChatClient(api,
var client = new BedrockTitanChatModel(api,
BedrockTitanChatOptions.builder()
.withTemperature(66.6f)
.withTopP(0.66f)

View File

@@ -26,7 +26,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingClient.InputType;
import org.springframework.ai.bedrock.titan.BedrockTitanEmbeddingModel.InputType;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingModel;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -44,19 +44,19 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockTitanEmbeddingClientIT {
class BedrockTitanEmbeddingModelIT {
@Autowired
private BedrockTitanEmbeddingClient embeddingClient;
private BedrockTitanEmbeddingModel embeddingModel;
@Test
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"),
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of("Hello World"),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.TEXT).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}
@Test
@@ -65,12 +65,12 @@ class BedrockTitanEmbeddingClientIT {
byte[] image = new DefaultResourceLoader().getResource("classpath:/spring_framework.png")
.getContentAsByteArray();
EmbeddingResponse embeddingResponse = embeddingClient
EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of(Base64.getEncoder().encodeToString(image)),
BedrockTitanEmbeddingOptions.builder().withInputType(InputType.IMAGE).build()));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(1024);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}
@SpringBootConfiguration
@@ -84,8 +84,8 @@ class BedrockTitanEmbeddingClientIT {
}
@Bean
public BedrockTitanEmbeddingClient titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
return new BedrockTitanEmbeddingClient(titanEmbeddingApi);
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
return new BedrockTitanEmbeddingModel(titanEmbeddingApi);
}
}

View File

@@ -55,10 +55,10 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@EnabledIfEnvironmentVariable(named = "AWS_SECRET_ACCESS_KEY", matches = ".*")
class BedrockTitanChatClientIT {
class BedrockTitanModelCalerlIT {
@Autowired
private BedrockTitanChatClient client;
private BedrockTitanChatModel chatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -66,8 +66,8 @@ class BedrockTitanChatClientIT {
@Test
void multipleStreamAttempts() {
Flux<ChatResponse> joke1Stream = client.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = client.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
Flux<ChatResponse> joke1Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a joke?")));
Flux<ChatResponse> joke2Stream = chatModel.stream(new Prompt(new UserMessage("Tell me a toy joke?")));
String joke1 = joke1Stream.collectList()
.block()
@@ -99,7 +99,7 @@ class BedrockTitanChatClientIT {
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -117,7 +117,7 @@ class BedrockTitanChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -138,7 +138,7 @@ class BedrockTitanChatClientIT {
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -162,7 +162,7 @@ class BedrockTitanChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -184,7 +184,7 @@ class BedrockTitanChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -211,8 +211,8 @@ class BedrockTitanChatClientIT {
}
@Bean
public BedrockTitanChatClient titanChatClient(TitanChatBedrockApi titanApi) {
return new BedrockTitanChatClient(titanApi);
public BedrockTitanChatModel titanChatModel(TitanChatBedrockApi titanApi) {
return new BedrockTitanChatModel(titanApi);
}
}

View File

@@ -22,7 +22,7 @@ import java.util.Map;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
@@ -31,15 +31,17 @@ import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
import org.springframework.ai.huggingface.model.GenerateParameters;
import org.springframework.ai.huggingface.model.GenerateRequest;
import org.springframework.ai.huggingface.model.GenerateResponse;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.ChatOptionsBuilder;
import org.springframework.ai.chat.prompt.Prompt;
/**
* An implementation of {@link ChatClient} that interfaces with HuggingFace Inference
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
* Endpoints for text generation.
*
* @author Mark Pollack
*/
public class HuggingfaceChatClient implements ChatClient {
public class HuggingfaceChatModel implements ChatModel {
/**
* Token required for authenticating with the HuggingFace Inference API.
@@ -68,11 +70,11 @@ public class HuggingfaceChatClient implements ChatClient {
private int maxNewTokens = 1000;
/**
* Constructs a new HuggingfaceChatClient with the specified API token and base path.
* Constructs a new HuggingfaceChatModel with the specified API token and base path.
* @param apiToken The API token for HuggingFace.
* @param basePath The base path for API requests.
*/
public HuggingfaceChatClient(final String apiToken, String basePath) {
public HuggingfaceChatModel(final String apiToken, String basePath) {
this.apiToken = apiToken;
this.apiClient.setBasePath(basePath);
this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken);
@@ -120,4 +122,9 @@ public class HuggingfaceChatClient implements ChatClient {
this.maxNewTokens = maxNewTokens;
}
/**
 * Returns the default {@link ChatOptions} for this chat model.
 * <p>
 * Each call runs {@code ChatOptionsBuilder.builder().build()} without invoking any
 * builder setters, so the result does not depend on fields of this instance.
 * @return the options produced by an unconfigured {@link ChatOptionsBuilder}
 */
@Override
public ChatOptions getDefaultOptions() {
return ChatOptionsBuilder.builder().build();
}
}

View File

@@ -23,7 +23,7 @@ import org.springframework.util.StringUtils;
public class HuggingfaceTestConfiguration {
@Bean
public HuggingfaceChatClient huggingfaceChatClient() {
public HuggingfaceChatModel huggingfaceChatModel() {
String apiKey = System.getenv("HUGGINGFACE_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
@@ -31,9 +31,9 @@ public class HuggingfaceTestConfiguration {
}
// Created aws-mistral-7b-instruct-v0-1-805 via
// https://ui.endpoints.huggingface.co/
HuggingfaceChatClient huggingfaceChatClient = new HuggingfaceChatClient(apiKey,
HuggingfaceChatModel huggingfaceChatModel = new HuggingfaceChatModel(apiKey,
"https://f6hg7b3cvlmntp5i.us-east-1.aws.endpoints.huggingface.cloud");
return huggingfaceChatClient;
return huggingfaceChatModel;
}
}

View File

@@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.huggingface.HuggingfaceChatClient;
import org.springframework.ai.huggingface.HuggingfaceChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
@@ -33,7 +33,7 @@ import static org.assertj.core.api.Assertions.assertThat;
public class ClientIT {
@Autowired
protected HuggingfaceChatClient huggingfaceChatClient;
protected HuggingfaceChatModel huggingfaceChatModel;
@Test
void helloWorldCompletion() {
@@ -46,7 +46,7 @@ public class ClientIT {
[/INST]
""";
Prompt prompt = new Prompt(mistral7bInstruct);
ChatResponse chatResponse = huggingfaceChatClient.call(prompt);
ChatResponse chatResponse = huggingfaceChatModel.call(prompt);
assertThat(chatResponse.getResult().getOutput().getContent()).isNotEmpty();
String expectedResponse = """
```json

View File

@@ -15,30 +15,6 @@
*/
package org.springframework.ai.minimax;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.minimax.api.MiniMaxApi.*;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import java.util.Base64;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -47,21 +23,51 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionFinishReason;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.Role;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionMessage.ToolCall;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionRequest;
import org.springframework.ai.minimax.api.MiniMaxApi.FunctionTool;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* {@link ChatClient} and {@link StreamingChatClient} implementation for
* {@literal MiniMax} backed by {@link MiniMaxApi}.
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal MiniMax}
* backed by {@link MiniMaxApi}.
*
* @author Geng Rong
* @see ChatClient
* @see StreamingChatClient
* @see ChatModel
* @see StreamingChatModel
* @see MiniMaxApi
* @since 1.0.0 M1
*/
public class MiniMaxChatClient extends
AbstractFunctionCallSupport<ChatCompletionMessage, ChatCompletionRequest, ResponseEntity<ChatCompletion>>
implements ChatClient, StreamingChatClient {
private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatClient.class);
public class MiniMaxChatModel extends
AbstractFunctionCallSupport<MiniMaxApi.ChatCompletionMessage, MiniMaxApi.ChatCompletionRequest, ResponseEntity<MiniMaxApi.ChatCompletion>>
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(MiniMaxChatModel.class);
/**
* The default options used for the chat completion requests.
@@ -79,35 +85,35 @@ public class MiniMaxChatClient extends
private final MiniMaxApi miniMaxApi;
/**
* Creates an instance of the MiniMaxChatClient.
* Creates an instance of the MiniMaxChatModel.
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
* MiniMax Chat API.
* MiniMax Chat API.
* @throws IllegalArgumentException if MiniMaxApi is null
*/
public MiniMaxChatClient(MiniMaxApi miniMaxApi) {
public MiniMaxChatModel(MiniMaxApi miniMaxApi) {
this(miniMaxApi,
MiniMaxChatOptions.builder().withModel(MiniMaxApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
}
/**
* Initializes an instance of the MiniMaxChatClient.
* Initializes an instance of the MiniMaxChatModel.
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
* MiniMax Chat API.
* @param options The MiniMaxChatOptions to configure the chat client.
* @param options The MiniMaxChatOptions to configure the chat model.
*/
public MiniMaxChatClient(MiniMaxApi miniMaxApi, MiniMaxChatOptions options) {
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options) {
this(miniMaxApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the MiniMaxChatClient.
* Initializes a new instance of the MiniMaxChatModel.
* @param miniMaxApi The MiniMaxApi instance to be used for interacting with the
* MiniMax Chat API.
* @param options The MiniMaxChatOptions to configure the chat client.
* @param options The MiniMaxChatOptions to configure the chat model.
* @param functionCallbackContext The function callback context.
* @param retryTemplate The retry template.
*/
public MiniMaxChatClient(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
public MiniMaxChatModel(MiniMaxApi miniMaxApi, MiniMaxChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
@@ -279,7 +285,7 @@ public class MiniMaxChatClient extends
return request;
}
private List<FunctionTool> getFunctionTools(Set<String> functionNames) {
private List<MiniMaxApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
var function = new FunctionTool.Function(functionCallback.getDescription(), functionCallback.getName(),
functionCallback.getInputTypeSchema());
@@ -358,4 +364,9 @@ public class MiniMaxChatClient extends
&& choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
}
/**
 * Returns the default chat options of this model.
 * <p>
 * The value handed back is produced by
 * {@link MiniMaxChatOptions#fromOptions(MiniMaxChatOptions)} from the model's
 * {@code defaultOptions} field rather than being that field itself.
 * @return options built from this model's defaults
 */
@Override
public ChatOptions getDefaultOptions() {
	ChatOptions defaultsCopy = MiniMaxChatOptions.fromOptions(this.defaultOptions);
	return defaultsCopy;
}
}

View File

@@ -114,10 +114,10 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
private @JsonProperty("tool_choice") String toolChoice;
/**
* MiniMax Tool Function Callbacks to register with the ChatClient.
* MiniMax Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
* from the registry to be used by the ChatClient chat completion requests.
* from the registry to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
@@ -467,4 +467,22 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
return true;
}
/**
 * Builds a {@link MiniMaxChatOptions} through the builder, taking every value from
 * the given instance.
 * <p>
 * Each value is passed to the builder exactly as the corresponding getter returns
 * it. NOTE(review): whether collection-typed values (stop, tools, function
 * callbacks, functions) end up shared with the source depends on the builder and
 * setters, which are not visible here.
 * @param fromOptions the options to read from; dereferenced directly, so a
 * {@code null} argument results in a {@link NullPointerException}
 * @return the options produced by the builder
 */
public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) {
	// Stage 1: model selection and generation/sampling parameters.
	var generationStage = builder().withModel(fromOptions.getModel())
		.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
		.withMaxTokens(fromOptions.getMaxTokens())
		.withN(fromOptions.getN())
		.withPresencePenalty(fromOptions.getPresencePenalty())
		.withResponseFormat(fromOptions.getResponseFormat())
		.withSeed(fromOptions.getSeed())
		.withStop(fromOptions.getStop())
		.withTemperature(fromOptions.getTemperature())
		.withTopP(fromOptions.getTopP());

	// Stage 2: tool and function-calling configuration, continuing the same chain.
	var toolingStage = generationStage.withTools(fromOptions.getTools())
		.withToolChoice(fromOptions.getToolChoice())
		.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
		.withFunctions(fromOptions.getFunctions());

	return toolingStage.build();
}
}

View File

@@ -19,7 +19,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -40,9 +40,9 @@ import java.util.List;
* @author Geng Rong
* @since 1.0.0 M1
*/
public class MiniMaxEmbeddingClient extends AbstractEmbeddingClient {
public class MiniMaxEmbeddingModel extends AbstractEmbeddingModel {
private static final Logger logger = LoggerFactory.getLogger(MiniMaxEmbeddingClient.class);
private static final Logger logger = LoggerFactory.getLogger(MiniMaxEmbeddingModel.class);
private final MiniMaxEmbeddingOptions defaultOptions;
@@ -53,43 +53,43 @@ public class MiniMaxEmbeddingClient extends AbstractEmbeddingClient {
private final MetadataMode metadataMode;
/**
* Constructor for the MiniMaxEmbeddingClient class.
* Constructor for the MiniMaxEmbeddingModel class.
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
*/
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi) {
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi) {
this(miniMaxApi, MetadataMode.EMBED);
}
/**
* Initializes a new instance of the MiniMaxEmbeddingClient class.
* Initializes a new instance of the MiniMaxEmbeddingModel class.
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
* @param metadataMode The mode for generating metadata.
*/
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode) {
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode) {
this(miniMaxApi, metadataMode,
MiniMaxEmbeddingOptions.builder().withModel(MiniMaxApi.DEFAULT_EMBEDDING_MODEL).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the MiniMaxEmbeddingClient class.
* Initializes a new instance of the MiniMaxEmbeddingModel class.
* @param miniMaxApi The MiniMaxApi instance to use for making API requests.
* @param metadataMode The mode for generating metadata.
* @param miniMaxEmbeddingOptions The options for MiniMax embedding.
*/
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode,
MiniMaxEmbeddingOptions miniMaxEmbeddingOptions) {
this(miniMaxApi, metadataMode, miniMaxEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the MiniMaxEmbeddingClient class.
* Initializes a new instance of the MiniMaxEmbeddingModel class.
* @param miniMaxApi - The MiniMaxApi instance to use for making API requests.
* @param metadataMode - The mode for generating metadata.
* @param options - The options for MiniMax embedding.
* @param retryTemplate - The RetryTemplate for retrying failed API requests.
*/
public MiniMaxEmbeddingClient(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, MiniMaxEmbeddingOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(miniMaxApi, "MiniMaxApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");

View File

@@ -19,6 +19,8 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
@@ -111,7 +113,7 @@ public class MiniMaxApi {
* MiniMax Chat Completion Models:
* <a href="https://www.minimaxi.com/document/algorithm-concept">MiniMax Model</a>.
*/
public enum ChatModel {
public enum ChatModel implements ModelDescription {
ABAB_6_Chat("abab6-chat"),
ABAB_5_5_Chat("abab5.5-chat"),
ABAB_5_5_S_Chat("abab5.5s-chat");
@@ -125,6 +127,11 @@ public class MiniMaxApi {
/**
 * Returns the string value held by this enum constant.
 * @return the constant's {@code value} field
 */
public String getValue() {
	return this.value;
}
/**
 * Returns the model name for this constant, which is the same string exposed by
 * {@link #getValue()}.
 * @return the constant's value string
 */
@Override
public String getModelName() {
	return getValue();
}
}
/**

View File

@@ -33,7 +33,7 @@ public class ChatCompletionRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
MiniMaxChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
var request = client.createRequest(new Prompt("Test message content"), false);
@@ -59,7 +59,7 @@ public class ChatCompletionRequestTests {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
MiniMaxChatOptions.builder().withModel("DEFAULT_MODEL").build());
var request = client.createRequest(new Prompt("Test message content",
@@ -89,7 +89,7 @@ public class ChatCompletionRequestTests {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new MiniMaxChatClient(new MiniMaxApi("TEST"),
var client = new MiniMaxChatModel(new MiniMaxApi("TEST"),
MiniMaxChatOptions.builder()
.withModel("DEFAULT_MODEL")
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())

View File

@@ -15,7 +15,7 @@
*/
package org.springframework.ai.minimax;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.minimax.api.MiniMaxApi;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
@@ -42,13 +42,13 @@ public class MiniMaxTestConfiguration {
}
@Bean
public MiniMaxChatClient miniMaxChatClient(MiniMaxApi api) {
return new MiniMaxChatClient(api);
public MiniMaxChatModel miniMaxChatModel(MiniMaxApi api) {
return new MiniMaxChatModel(api);
}
@Bean
public EmbeddingClient miniMaxEmbeddingClient(MiniMaxApi api) {
return new MiniMaxEmbeddingClient(api);
public EmbeddingModel miniMaxEmbeddingModel(MiniMaxApi api) {
return new MiniMaxEmbeddingModel(api);
}
}

View File

@@ -22,9 +22,9 @@ import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.minimax.MiniMaxChatClient;
import org.springframework.ai.minimax.MiniMaxChatModel;
import org.springframework.ai.minimax.MiniMaxChatOptions;
import org.springframework.ai.minimax.MiniMaxEmbeddingClient;
import org.springframework.ai.minimax.MiniMaxEmbeddingModel;
import org.springframework.ai.minimax.MiniMaxEmbeddingOptions;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletion;
import org.springframework.ai.minimax.api.MiniMaxApi.ChatCompletionChunk;
@@ -83,9 +83,9 @@ public class MiniMaxRetryTests {
private @Mock MiniMaxApi miniMaxApi;
private MiniMaxChatClient chatClient;
private MiniMaxChatModel chatModel;
private MiniMaxEmbeddingClient embeddingClient;
private MiniMaxEmbeddingModel embeddingModel;
@BeforeEach
public void beforeEach() {
@@ -93,8 +93,8 @@ public class MiniMaxRetryTests {
retryListener = new TestRetryListener();
retryTemplate.registerListener(retryListener);
chatClient = new MiniMaxChatClient(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate);
embeddingClient = new MiniMaxEmbeddingClient(miniMaxApi, MetadataMode.EMBED,
chatModel = new MiniMaxChatModel(miniMaxApi, MiniMaxChatOptions.builder().build(), null, retryTemplate);
embeddingModel = new MiniMaxEmbeddingModel(miniMaxApi, MetadataMode.EMBED,
MiniMaxEmbeddingOptions.builder().build(), retryTemplate);
}
@@ -111,7 +111,7 @@ public class MiniMaxRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
var result = chatClient.call(new Prompt("text"));
var result = chatModel.call(new Prompt("text"));
assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
@@ -123,7 +123,7 @@ public class MiniMaxRetryTests {
public void miniMaxChatNonTransientError() {
when(miniMaxApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text")));
}
@Test
@@ -139,7 +139,7 @@ public class MiniMaxRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(Flux.just(expectedChatCompletion));
var result = chatClient.stream(new Prompt("text"));
var result = chatModel.stream(new Prompt("text"));
assertThat(result).isNotNull();
assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
@@ -151,7 +151,7 @@ public class MiniMaxRetryTests {
public void miniMaxChatStreamNonTransientError() {
when(miniMaxApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
}
@Test
@@ -164,7 +164,7 @@ public class MiniMaxRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
var result = embeddingClient
var result = embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
assertThat(result).isNotNull();
@@ -177,7 +177,7 @@ public class MiniMaxRetryTests {
public void miniMaxEmbeddingNonTransientError() {
when(miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> embeddingClient
assertThrows(RuntimeException.class, () -> embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
}

View File

@@ -17,10 +17,10 @@ package org.springframework.ai.mistralai;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
@@ -55,9 +55,9 @@ import java.util.concurrent.ConcurrentHashMap;
* @author Grogdunn
* @since 0.8.1
*/
public class MistralAiChatClient extends
public class MistralAiChatModel extends
AbstractFunctionCallSupport<MistralAiApi.ChatCompletionMessage, MistralAiApi.ChatCompletionRequest, ResponseEntity<MistralAiApi.ChatCompletion>>
implements ChatClient, StreamingChatClient {
implements ChatModel, StreamingChatModel {
private final Logger log = LoggerFactory.getLogger(getClass());
@@ -73,7 +73,7 @@ public class MistralAiChatClient extends
private final RetryTemplate retryTemplate;
public MistralAiChatClient(MistralAiApi mistralAiApi) {
public MistralAiChatModel(MistralAiApi mistralAiApi) {
this(mistralAiApi,
MistralAiChatOptions.builder()
.withTemperature(0.7f)
@@ -83,11 +83,11 @@ public class MistralAiChatClient extends
.build());
}
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
public MistralAiChatModel(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
@@ -324,4 +324,9 @@ public class MistralAiChatClient extends
return !CollectionUtils.isEmpty(choices.get(0).message().toolCalls());
}
/**
 * Returns the default chat options of this model.
 * <p>
 * The value handed back is produced by
 * {@link MistralAiChatOptions#fromOptions(MistralAiChatOptions)} from the model's
 * {@code defaultOptions} field rather than being that field itself.
 * @return options built from this model's defaults
 */
@Override
public ChatOptions getDefaultOptions() {
	ChatOptions defaultsCopy = MistralAiChatOptions.fromOptions(this.defaultOptions);
	return defaultsCopy;
}
}

View File

@@ -101,11 +101,11 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
private @JsonProperty("tool_choice") ToolChoice toolChoice;
/**
* MistralAI Tool Function Callbacks to register with the ChatClient. For Prompt
* MistralAI Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatClient chat completion requests.
* to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
@@ -139,7 +139,7 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
return this;
}
public Builder withMaxToken(Integer maxTokens) {
public Builder withMaxTokens(Integer maxTokens) {
this.options.setMaxTokens(maxTokens);
return this;
}
@@ -309,4 +309,19 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
this.functions = functions;
}
/**
 * Builds a {@link MistralAiChatOptions} through the builder, taking every value
 * from the given instance.
 * <p>
 * Each value is passed to the builder exactly as the corresponding getter returns
 * it. NOTE(review): whether collection-typed values (tools, function callbacks,
 * functions) end up shared with the source depends on the builder and setters,
 * which are only partly visible here.
 * @param fromOptions the options to read from; dereferenced directly, so a
 * {@code null} argument results in a {@link NullPointerException}
 * @return the options produced by the builder
 */
public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) {
	// Stage 1: model selection and generation/sampling parameters.
	var generationStage = builder().withModel(fromOptions.getModel())
		.withMaxTokens(fromOptions.getMaxTokens())
		.withSafePrompt(fromOptions.getSafePrompt())
		.withRandomSeed(fromOptions.getRandomSeed())
		.withTemperature(fromOptions.getTemperature())
		.withTopP(fromOptions.getTopP())
		.withResponseFormat(fromOptions.getResponseFormat());

	// Stage 2: tool and function-calling configuration, continuing the same chain.
	var toolingStage = generationStage.withTools(fromOptions.getTools())
		.withToolChoice(fromOptions.getToolChoice())
		.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
		.withFunctions(fromOptions.getFunctions());

	return toolingStage.build();
}
}

View File

@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -38,7 +38,7 @@ import org.springframework.util.Assert;
* @author Ricken Bazolo
* @since 0.8.1
*/
public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
public class MistralAiEmbeddingModel extends AbstractEmbeddingModel {
private final Logger log = LoggerFactory.getLogger(getClass());
@@ -50,21 +50,21 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
private final RetryTemplate retryTemplate;
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi) {
this(mistralAiApi, MetadataMode.EMBED);
}
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
this(mistralAiApi, metadataMode,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode,
public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode,
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) {
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");

View File

@@ -27,6 +27,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
@@ -706,7 +707,7 @@ public class MistralAiApi {
* <li><b>LARGE</b> - mistral-large-latest (aka mistral-large-2402)</li>
* </ul>
*/
public enum ChatModel {
public enum ChatModel implements ModelDescription {
// @formatter:off
TINY("open-mistral-7b"),
@@ -726,6 +727,11 @@ public class MistralAiApi {
return this.value;
}
/**
 * Returns the model name for this constant, taken from its {@code value} field.
 * @return the constant's value string
 */
@Override
public String getModelName() {
return this.value;
}
}
/**

View File

@@ -32,12 +32,12 @@ import static org.assertj.core.api.Assertions.assertThat;
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
public class MistralAiChatCompletionRequestTest {
MistralAiChatClient chatClient = new MistralAiChatClient(new MistralAiApi("test"));
MistralAiChatModel chatModel = new MistralAiChatModel(new MistralAiApi("test"));
@Test
void chatCompletionDefaultRequestTest() {
var request = chatClient.createRequest(new Prompt("test content"), false);
var request = chatModel.createRequest(new Prompt("test content"), false);
assertThat(request.messages()).hasSize(1);
assertThat(request.topP()).isEqualTo(1);
@@ -52,7 +52,7 @@ public class MistralAiChatCompletionRequestTest {
var options = MistralAiChatOptions.builder().withTemperature(0.5f).withTopP(0.8f).build();
var request = chatClient.createRequest(new Prompt("test content", options), true);
var request = chatModel.createRequest(new Prompt("test content", options), true);
assertThat(request.messages().size()).isEqualTo(1);
assertThat(request.topP()).isEqualTo(0.8f);

View File

@@ -27,10 +27,10 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
@@ -56,15 +56,15 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
@SpringBootTest(classes = MistralAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
class MistralAiChatClientIT {
class MistralAiChatModelIT {
private static final Logger logger = LoggerFactory.getLogger(MistralAiChatClientIT.class);
private static final Logger logger = LoggerFactory.getLogger(MistralAiChatModelIT.class);
@Autowired
protected ChatClient chatClient;
protected ChatModel chatModel;
@Autowired
protected StreamingChatClient streamingChatClient;
protected StreamingChatModel streamingChatModel;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@@ -90,7 +90,7 @@ class MistralAiChatClientIT {
// NOTE: Mistral expects the system message to be before the user message or will
// fail with 400 error.
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
}
@@ -108,7 +108,7 @@ class MistralAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.chatClient.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -126,7 +126,7 @@ class MistralAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -148,7 +148,7 @@ class MistralAiChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
logger.info("" + actorsFilms);
@@ -169,7 +169,7 @@ class MistralAiChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = streamingChatClient.stream(prompt)
String generationTextFromStream = streamingChatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -201,7 +201,7 @@ class MistralAiChatClientIT {
.build()))
.build();
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
ChatResponse response = chatModel.call(new Prompt(messages, promptOptions));
logger.info("Response: {}", response);
@@ -224,7 +224,7 @@ class MistralAiChatClientIT {
.build()))
.build();
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(messages, promptOptions));
Flux<ChatResponse> response = streamingChatModel.stream(new Prompt(messages, promptOptions));
String content = response.collectList()
.block()

View File

@@ -31,25 +31,25 @@ import static org.assertj.core.api.Assertions.assertThat;
class MistralAiEmbeddingIT {
@Autowired
private MistralAiEmbeddingClient mistralAiEmbeddingClient;
private MistralAiEmbeddingModel mistralAiEmbeddingModel;
@Test
void defaultEmbedding() {
assertThat(mistralAiEmbeddingClient).isNotNull();
var embeddingResponse = mistralAiEmbeddingClient.embedForResponse(List.of("Hello World"));
assertThat(mistralAiEmbeddingModel).isNotNull();
var embeddingResponse = mistralAiEmbeddingModel.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1024);
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "mistral-embed");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 4);
assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 4);
assertThat(mistralAiEmbeddingClient.dimensions()).isEqualTo(1024);
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
}
@Test
void embeddingTest() {
assertThat(mistralAiEmbeddingClient).isNotNull();
var embeddingResponse = mistralAiEmbeddingClient.call(new EmbeddingRequest(
assertThat(mistralAiEmbeddingModel).isNotNull();
var embeddingResponse = mistralAiEmbeddingModel.call(new EmbeddingRequest(
List.of("Hello World", "World is big"),
MistralAiEmbeddingOptions.builder().withModel("mistral-embed").withEncodingFormat("float").build()));
assertThat(embeddingResponse.getResults()).hasSize(2);
@@ -58,7 +58,7 @@ class MistralAiEmbeddingIT {
assertThat(embeddingResponse.getMetadata()).containsEntry("model", "mistral-embed");
assertThat(embeddingResponse.getMetadata()).containsEntry("total-tokens", 9);
assertThat(embeddingResponse.getMetadata()).containsEntry("prompt-tokens", 9);
assertThat(mistralAiEmbeddingClient.dimensions()).isEqualTo(1024);
assertThat(mistralAiEmbeddingModel.dimensions()).isEqualTo(1024);
}
}

View File

@@ -82,9 +82,9 @@ public class MistralAiRetryTests {
private @Mock MistralAiApi mistralAiApi;
private MistralAiChatClient chatClient;
private MistralAiChatModel chatModel;
private MistralAiEmbeddingClient embeddingClient;
private MistralAiEmbeddingModel embeddingModel;
@BeforeEach
public void beforeEach() {
@@ -92,7 +92,7 @@ public class MistralAiRetryTests {
retryListener = new TestRetryListener();
retryTemplate.registerListener(retryListener);
chatClient = new MistralAiChatClient(mistralAiApi,
chatModel = new MistralAiChatModel(mistralAiApi,
MistralAiChatOptions.builder()
.withTemperature(0.7f)
.withTopP(1f)
@@ -100,7 +100,7 @@ public class MistralAiRetryTests {
.withModel(MistralAiApi.ChatModel.TINY.getValue())
.build(),
null, retryTemplate);
embeddingClient = new MistralAiEmbeddingClient(mistralAiApi, MetadataMode.EMBED,
embeddingModel = new MistralAiEmbeddingModel(mistralAiApi, MetadataMode.EMBED,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
retryTemplate);
}
@@ -118,7 +118,7 @@ public class MistralAiRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
var result = chatClient.call(new Prompt("text"));
var result = chatModel.call(new Prompt("text"));
assertThat(result).isNotNull();
assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
@@ -130,7 +130,7 @@ public class MistralAiRetryTests {
public void mistralAiChatNonTransientError() {
when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.call(new Prompt("text")));
}
@Test
@@ -146,7 +146,7 @@ public class MistralAiRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(Flux.just(expectedChatCompletion));
var result = chatClient.stream(new Prompt("text"));
var result = chatModel.stream(new Prompt("text"));
assertThat(result).isNotNull();
assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
@@ -158,7 +158,7 @@ public class MistralAiRetryTests {
public void mistralAiChatStreamNonTransientError() {
when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
assertThrows(RuntimeException.class, () -> chatModel.stream(new Prompt("text")));
}
@Test
@@ -172,7 +172,7 @@ public class MistralAiRetryTests {
.thenThrow(new TransientAiException("Transient Error 2"))
.thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
var result = embeddingClient
var result = embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
assertThat(result).isNotNull();
@@ -185,7 +185,7 @@ public class MistralAiRetryTests {
public void mistralAiEmbeddingNonTransientError() {
when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
.thenThrow(new RuntimeException("Non Transient Error"));
assertThrows(RuntimeException.class, () -> embeddingClient
assertThrows(RuntimeException.class, () -> embeddingModel
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
}

View File

@@ -15,7 +15,7 @@
*/
package org.springframework.ai.mistralai;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
@@ -35,14 +35,14 @@ public class MistralAiTestConfiguration {
}
@Bean
public EmbeddingClient mistralAiEmbeddingClient(MistralAiApi api) {
return new MistralAiEmbeddingClient(api,
public EmbeddingModel mistralAiEmbeddingModel(MistralAiApi api) {
return new MistralAiEmbeddingModel(api,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
}
@Bean
public MistralAiChatClient mistralAiChatClient(MistralAiApi mistralAiApi) {
return new MistralAiChatClient(mistralAiApi,
public MistralAiChatModel mistralAiChatModel(MistralAiApi mistralAiApi) {
return new MistralAiChatModel(mistralAiApi,
MistralAiChatOptions.builder().withModel(MistralAiApi.ChatModel.MIXTRAL.getValue()).build());
}

View File

@@ -18,13 +18,13 @@ package org.springframework.ai.ollama;
import java.util.Base64;
import java.util.List;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.ollama.metadata.OllamaChatResponseMetadata;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
@@ -39,7 +39,7 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* {@link ChatClient} implementation for {@literal Ollama}.
* {@link ChatModel} implementation for {@literal Ollama}.
*
* Ollama allows developers to run large language models and generate embeddings locally.
* It supports open-source models available on [Ollama AI
@@ -52,7 +52,7 @@ import org.springframework.util.StringUtils;
* @author Christian Tzolov
* @since 0.8.0
*/
public class OllamaChatClient implements ChatClient, StreamingChatClient {
public class OllamaChatModel implements ChatModel, StreamingChatModel {
/**
* Low-level Ollama API library.
@@ -64,11 +64,11 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
*/
private OllamaOptions defaultOptions;
public OllamaChatClient(OllamaApi chatApi) {
public OllamaChatModel(OllamaApi chatApi) {
this(chatApi, OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL));
}
public OllamaChatClient(OllamaApi chatApi, OllamaOptions defaultOptions) {
public OllamaChatModel(OllamaApi chatApi, OllamaOptions defaultOptions) {
Assert.notNull(chatApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.chatApi = chatApi;
@@ -79,7 +79,7 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
* @deprecated Use {@link OllamaOptions#setModel} instead.
*/
@Deprecated
public OllamaChatClient withModel(String model) {
public OllamaChatModel withModel(String model) {
this.defaultOptions.setModel(model);
return this;
}
@@ -88,7 +88,7 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
* @deprecated Use {@link OllamaOptions} constructor instead.
*/
@Deprecated
public OllamaChatClient withDefaultOptions(OllamaOptions options) {
public OllamaChatModel withDefaultOptions(OllamaOptions options) {
this.defaultOptions = options;
return this;
}
@@ -205,4 +205,9 @@ public class OllamaChatClient implements ChatClient, StreamingChatClient {
}
}
/**
 * Exposes the default options of this chat model.
 * <p>
 * The value handed out is whatever {@link OllamaOptions#fromOptions(OllamaOptions)}
 * produces for the internal {@code defaultOptions} field, not the field itself.
 * @return the options produced by {@code OllamaOptions.fromOptions(defaultOptions)}
 */
@Override
public ChatOptions getDefaultOptions() {
    OllamaOptions defaultsCopy = OllamaOptions.fromOptions(this.defaultOptions);
    return defaultsCopy;
}
}

View File

@@ -23,9 +23,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.model.ModelOptionsUtils;
@@ -36,7 +36,7 @@ import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* {@link EmbeddingClient} implementation for {@literal Ollama}.
* {@link EmbeddingModel} implementation for {@literal Ollama}.
*
* Ollama allows developers to run large language models and generate embeddings locally.
* It supports open-source models available on [Ollama AI
@@ -51,7 +51,7 @@ import org.springframework.util.StringUtils;
* @author Christian Tzolov
* @since 0.8.0
*/
public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
public class OllamaEmbeddingModel extends AbstractEmbeddingModel {
private final Logger logger = LoggerFactory.getLogger(getClass());
@@ -62,11 +62,11 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
*/
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
public OllamaEmbeddingClient(OllamaApi ollamaApi) {
public OllamaEmbeddingModel(OllamaApi ollamaApi) {
this.ollamaApi = ollamaApi;
}
public OllamaEmbeddingClient(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this.ollamaApi = ollamaApi;
this.defaultOptions = defaultOptions;
}
@@ -75,7 +75,7 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
* @deprecated Use {@link OllamaOptions#setModel} instead.
*/
@Deprecated
public OllamaEmbeddingClient withModel(String model) {
public OllamaEmbeddingModel withModel(String model) {
this.defaultOptions.setModel(model);
return this;
}
@@ -84,7 +84,7 @@ public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
* @deprecated Use {@link OllamaOptions} constructor instead.
*/
@Deprecated
public OllamaEmbeddingClient withDefaultOptions(OllamaOptions options) {
public OllamaEmbeddingModel withDefaultOptions(OllamaOptions options) {
this.defaultOptions = options;
return this;
}

View File

@@ -15,13 +15,15 @@
*/
package org.springframework.ai.ollama.api;
import org.springframework.ai.model.ModelDescription;
/**
* Helper class for common Ollama models.
*
* @author Siarhei Blashuk
* @since 0.8.1
*/
public enum OllamaModel {
public enum OllamaModel implements ModelDescription {
/**
* Llama 2 is a collection of language models ranging from 7B to 70B parameters.
@@ -99,4 +101,9 @@ public enum OllamaModel {
return this.id;
}
/**
 * Returns the model name of this enum constant, which is its {@code id} value.
 * @return the {@code id} of this constant
 */
@Override
public String getModelName() {
    return id;
}
}

View File

@@ -714,6 +714,43 @@ public class OllamaOptions implements ChatOptions, EmbeddingOptions {
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
/**
 * Builds an {@link OllamaOptions} from another instance by reading every option
 * through its getter and applying it with the matching {@code with...} method,
 * starting from a fresh {@code new OllamaOptions()}.
 * <p>
 * NOTE(review): values are passed through exactly as the getters return them;
 * whether the {@code stop} value ends up shared with the source depends on
 * {@code getStop()}/{@code withStop(..)} - confirm if an independent copy is required.
 * @param fromOptions the options to read from; dereferenced directly, so it must not
 * be {@code null}
 * @return the options object populated with the values of {@code fromOptions}
 */
public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
    // model, format, keepAlive
    OllamaOptions copy = new OllamaOptions()
        .withModel(fromOptions.getModel())
        .withFormat(fromOptions.getFormat())
        .withKeepAlive(fromOptions.getKeepAlive());

    // useNUMA .. numThread
    copy = copy
        .withUseNUMA(fromOptions.getUseNUMA())
        .withNumCtx(fromOptions.getNumCtx())
        .withNumBatch(fromOptions.getNumBatch())
        .withNumGQA(fromOptions.getNumGQA())
        .withNumGPU(fromOptions.getNumGPU())
        .withMainGPU(fromOptions.getMainGPU())
        .withLowVRAM(fromOptions.getLowVRAM())
        .withF16KV(fromOptions.getF16KV())
        .withLogitsAll(fromOptions.getLogitsAll())
        .withVocabOnly(fromOptions.getVocabOnly())
        .withUseMMap(fromOptions.getUseMMap())
        .withUseMLock(fromOptions.getUseMLock())
        .withNumThread(fromOptions.getNumThread());

    // numKeep .. stop
    copy = copy
        .withNumKeep(fromOptions.getNumKeep())
        .withSeed(fromOptions.getSeed())
        .withNumPredict(fromOptions.getNumPredict())
        .withTopK(fromOptions.getTopK())
        .withTopP(fromOptions.getTopP())
        .withTfsZ(fromOptions.getTfsZ())
        .withTypicalP(fromOptions.getTypicalP())
        .withRepeatLastN(fromOptions.getRepeatLastN())
        .withTemperature(fromOptions.getTemperature())
        .withRepeatPenalty(fromOptions.getRepeatPenalty())
        .withPresencePenalty(fromOptions.getPresencePenalty())
        .withFrequencyPenalty(fromOptions.getFrequencyPenalty())
        .withMirostat(fromOptions.getMirostat())
        .withMirostatTau(fromOptions.getMirostatTau())
        .withMirostatEta(fromOptions.getMirostatEta())
        .withPenalizeNewline(fromOptions.getPenalizeNewline())
        .withStop(fromOptions.getStop());

    return copy;
}
// @formatter:on

View File

@@ -56,11 +56,11 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@Testcontainers
@Disabled("For manual smoke testing only.")
class OllamaChatClientIT {
class OllamaChatModelIT {
private static String MODEL = "mistral";
private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class);
private static final Log logger = LogFactory.getLog(OllamaChatModelIT.class);
@Container
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32");
@@ -77,7 +77,7 @@ class OllamaChatClientIT {
}
@Autowired
private OllamaChatClient client;
private OllamaChatModel chatModel;
@Test
void roleTest() {
@@ -95,13 +95,13 @@ class OllamaChatClientIT {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), portableOptions);
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
// ollama specific options
var ollamaOptions = new OllamaOptions().withLowVRAM(true);
response = client.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
response = chatModel.call(new Prompt(List.of(userMessage, systemMessage), ollamaOptions));
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");
}
@@ -109,7 +109,7 @@ class OllamaChatClientIT {
@Test
void usageTest() {
Prompt prompt = new Prompt("Tell me a joke");
ChatResponse response = client.call(prompt);
ChatResponse response = chatModel.call(prompt);
Usage usage = response.getMetadata().getUsage();
assertThat(usage).isNotNull();
@@ -131,7 +131,7 @@ class OllamaChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors.", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.client.call(prompt).getResult();
Generation generation = this.chatModel.call(prompt).getResult();
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@@ -151,7 +151,7 @@ class OllamaChatClientIT {
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -173,7 +173,7 @@ class OllamaChatClientIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = client.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
@@ -194,7 +194,7 @@ class OllamaChatClientIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = client.stream(prompt)
String generationTextFromStream = chatModel.stream(prompt)
.collectList()
.block()
.stream()
@@ -219,8 +219,8 @@ class OllamaChatClientIT {
}
@Bean
public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
return new OllamaChatClient(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
}
}

View File

@@ -44,11 +44,11 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@Testcontainers
@Disabled("For manual smoke testing only.")
class OllamaChatClientMultimodalIT {
class OllamaChatModelMultimodalIT {
private static String MODEL = "llava";
private static final Log logger = LogFactory.getLog(OllamaChatClientIT.class);
// Logger category is the enclosing test class, so log lines are attributed to this test.
private static final Log logger = LogFactory.getLog(OllamaChatModelMultimodalIT.class);
@Container
static OllamaContainer ollamaContainer = new OllamaContainer("ollama/ollama:0.1.32");
@@ -65,7 +65,7 @@ class OllamaChatClientMultimodalIT {
}
@Autowired
private OllamaChatClient client;
private OllamaChatModel chatModel;
@Test
void multiModalityTest() throws IOException {
@@ -75,7 +75,7 @@ class OllamaChatClientMultimodalIT {
var userMessage = new UserMessage("Explain what do you see on this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
var response = client.call(new Prompt(List.of(userMessage)));
var response = chatModel.call(new Prompt(List.of(userMessage)));
logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "basket");
@@ -90,8 +90,8 @@ class OllamaChatClientMultimodalIT {
}
@Bean
public OllamaChatClient ollamaChat(OllamaApi ollamaApi) {
return new OllamaChatClient(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return new OllamaChatModel(ollamaApi, OllamaOptions.create().withModel(MODEL).withTemperature(0.9f));
}
}

View File

@@ -30,13 +30,13 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OllamaChatRequestTests {
OllamaChatClient client = new OllamaChatClient(new OllamaApi(),
OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
new OllamaOptions().withModel("MODEL_NAME").withTopK(99).withTemperature(66.6f).withNumGPU(1));
@Test
public void createRequestWithDefaultOptions() {
var request = client.ollamaChatRequest(new Prompt("Test message content"), false);
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), false);
assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();
@@ -54,7 +54,7 @@ public class OllamaChatRequestTests {
// Runtime options should override the default options.
OllamaOptions promptOptions = new OllamaOptions().withTemperature(0.8f).withTopP(0.5f).withNumGPU(2);
var request = client.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
@@ -79,7 +79,7 @@ public class OllamaChatRequestTests {
.withTopP(0.6f)
.build();
var request = client.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", portablePromptOptions), true);
assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
@@ -97,7 +97,7 @@ public class OllamaChatRequestTests {
// Ollama runtime options.
OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL");
var request = client.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
var request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
}
@@ -105,17 +105,17 @@ public class OllamaChatRequestTests {
@Test
public void createRequestWithDefaultOptionsModelOverride() {
OllamaChatClient client2 = new OllamaChatClient(new OllamaApi(),
OllamaChatModel chatModel = new OllamaChatModel(new OllamaApi(),
new OllamaOptions().withModel("DEFAULT_OPTIONS_MODEL"));
var request = client2.ollamaChatRequest(new Prompt("Test message content"), true);
var request = chatModel.ollamaChatRequest(new Prompt("Test message content"), true);
assertThat(request.model()).isEqualTo("DEFAULT_OPTIONS_MODEL");
// Prompt options should override the default options.
OllamaOptions promptOptions = new OllamaOptions().withModel("PROMPT_MODEL");
request = client2.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
request = chatModel.ollamaChatRequest(new Prompt("Test message content", promptOptions), true);
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
}

View File

@@ -23,9 +23,9 @@ import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.ollama.OllamaContainer;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaApi;
@@ -34,14 +34,13 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.testcontainers.ollama.OllamaContainer;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
@Disabled("For manual smoke testing only.")
@Testcontainers
class OllamaEmbeddingClientIT {
class OllamaEmbeddingModelIT {
// Logger category is the enclosing test class, so log lines are attributed to this test.
private static final Log logger = LogFactory.getLog(OllamaEmbeddingModelIT.class);
@@ -60,15 +59,15 @@ class OllamaEmbeddingClientIT {
}
@Autowired
private OllamaEmbeddingClient embeddingClient;
private OllamaEmbeddingModel embeddingModel;
@Test
void singleEmbedding() {
assertThat(embeddingClient).isNotNull();
EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello World"));
assertThat(embeddingModel).isNotNull();
EmbeddingResponse embeddingResponse = embeddingModel.embedForResponse(List.of("Hello World"));
assertThat(embeddingResponse.getResults()).hasSize(1);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingClient.dimensions()).isEqualTo(3200);
assertThat(embeddingModel.dimensions()).isEqualTo(3200);
}
@SpringBootConfiguration
@@ -80,8 +79,8 @@ class OllamaEmbeddingClientIT {
}
@Bean
public OllamaEmbeddingClient ollamaEmbedding(OllamaApi ollamaApi) {
return new OllamaEmbeddingClient(ollamaApi).withModel("orca-mini");
public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) {
return new OllamaEmbeddingModel(ollamaApi).withModel("orca-mini");
}
}

View File

@@ -28,13 +28,13 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OllamaEmbeddingRequestTests {
OllamaEmbeddingClient client = new OllamaEmbeddingClient(new OllamaApi()).withDefaultOptions(
OllamaEmbeddingModel chatModel = new OllamaEmbeddingModel(new OllamaApi()).withDefaultOptions(
new OllamaOptions().withModel("DEFAULT_MODEL").withMainGPU(11).withUseMMap(true).withNumGPU(1));
@Test
public void ollamaEmbeddingRequestDefaultOptions() {
var request = client.ollamaEmbeddingRequest("Hello", null);
var request = chatModel.ollamaEmbeddingRequest("Hello", null);
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.options().get("num_gpu")).isEqualTo(1);
@@ -51,7 +51,7 @@ public class OllamaEmbeddingRequestTests {
.withUseMMap(true)
.withNumGPU(2);
var request = client.ollamaEmbeddingRequest("Hello", promptOptions);
var request = chatModel.ollamaEmbeddingRequest("Hello", promptOptions);
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
assertThat(request.options().get("num_gpu")).isEqualTo(2);

View File

@@ -24,10 +24,10 @@ import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.audio.speech.Speech;
import org.springframework.ai.openai.audio.speech.SpeechClient;
import org.springframework.ai.openai.audio.speech.SpeechModel;
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
import org.springframework.ai.openai.audio.speech.SpeechResponse;
import org.springframework.ai.openai.audio.speech.StreamingSpeechClient;
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.http.ResponseEntity;
@@ -44,7 +44,7 @@ import java.time.Duration;
* @see OpenAiAudioApi
* @since 1.0.0-M1
*/
public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechClient {
public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel {
private final Logger logger = LoggerFactory.getLogger(getClass());
@@ -61,12 +61,12 @@ public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechCli
private final OpenAiAudioApi audioApi;
/**
* Initializes a new instance of the OpenAiAudioSpeechClient class with the provided
* Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
* OpenAiAudioApi. It uses the model tts-1, response format mp3, voice alloy, and the
* default speed of 1.0.
* @param audioApi The OpenAiAudioApi to use for speech synthesis.
*/
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi) {
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi) {
this(audioApi,
OpenAiAudioSpeechOptions.builder()
.withModel(OpenAiAudioApi.TtsModel.TTS_1.getValue())
@@ -77,13 +77,13 @@ public class OpenAiAudioSpeechClient implements SpeechClient, StreamingSpeechCli
}
/**
* Initializes a new instance of the OpenAiAudioSpeechClient class with the provided
* Initializes a new instance of the OpenAiAudioSpeechModel class with the provided
* OpenAiAudioApi and options.
* @param audioApi The OpenAiAudioApi to use for speech synthesis.
* @param options The OpenAiAudioSpeechOptions containing the speech synthesis
* options.
*/
public OpenAiAudioSpeechClient(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions options) {
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
Assert.notNull(options, "OpenAiSpeechOptions must not be null");
this.audioApi = audioApi;

View File

@@ -35,7 +35,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.model.Model;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
import org.springframework.ai.openai.audio.transcription.AudioTranscription;
@@ -59,8 +59,7 @@ import org.springframework.util.Assert;
* @see OpenAiAudioApi
* @since 0.8.1
*/
public class OpenAiAudioTranscriptionClient
implements ModelClient<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
public class OpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
private final Logger logger = LoggerFactory.getLogger(getClass());
@@ -71,11 +70,11 @@ public class OpenAiAudioTranscriptionClient
private final OpenAiAudioApi audioApi;
/**
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
* Audio Transcription API.
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
*/
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi) {
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi) {
this(audioApi,
OpenAiAudioTranscriptionOptions.builder()
.withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
@@ -86,25 +85,25 @@ public class OpenAiAudioTranscriptionClient
}
/**
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
* Audio Transcription API.
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
* @param options The OpenAiAudioTranscriptionOptions instance for configuring the
* audio transcription.
*/
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) {
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) {
this(audioApi, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* OpenAiAudioTranscriptionClient is a client class used to interact with the OpenAI
* OpenAiAudioTranscriptionModel is a client class used to interact with the OpenAI
* Audio Transcription API.
* @param audioApi The OpenAiAudioApi instance to be used for making API calls.
* @param options The OpenAiAudioTranscriptionOptions instance for configuring the
* audio transcription.
* @param retryTemplate The RetryTemplate instance for retrying failed API calls.
*/
public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options,
public OpenAiAudioTranscriptionModel(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
Assert.notNull(options, "OpenAiTranscriptionOptions must not be null");

View File

@@ -17,10 +17,10 @@ package org.springframework.ai.openai;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.StreamingChatModel;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -58,7 +58,7 @@ import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal OpenAI}
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
* backed by {@link OpenAiApi}.
*
* @author Mark Pollack
@@ -68,15 +68,15 @@ import java.util.concurrent.ConcurrentHashMap;
* @author Josh Long
* @author Jemin Huh
* @author Grogdunn
* @see ChatClient
* @see StreamingChatClient
* @see ChatModel
* @see StreamingChatModel
* @see OpenAiApi
*/
public class OpenAiChatClient extends
public class OpenAiChatModel extends
AbstractFunctionCallSupport<ChatCompletionMessage, OpenAiApi.ChatCompletionRequest, ResponseEntity<ChatCompletion>>
implements ChatClient, StreamingChatClient {
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClient.class);
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModel.class);
/**
* The default options used for the chat completion requests.
@@ -94,35 +94,35 @@ public class OpenAiChatClient extends
private final OpenAiApi openAiApi;
/**
* Creates an instance of the OpenAiChatClient.
* Creates an instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @throws IllegalArgumentException if openAiApi is null
*/
public OpenAiChatClient(OpenAiApi openAiApi) {
public OpenAiChatModel(OpenAiApi openAiApi) {
this(openAiApi,
OpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build());
}
/**
* Initializes an instance of the OpenAiChatClient.
* Initializes an instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat client.
* @param options The OpenAiChatOptions to configure the chat model.
*/
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options) {
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options) {
this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the OpenAiChatClient.
* Initializes a new instance of the OpenAiChatModel.
* @param openAiApi The OpenAiApi instance to be used for interacting with the OpenAI
* Chat API.
* @param options The OpenAiChatOptions to configure the chat client.
* @param options The OpenAiChatOptions to configure the chat model.
* @param functionCallbackContext The function callback context.
* @param retryTemplate The retry template.
*/
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options,
public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions options,
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(openAiApi, "OpenAiApi must not be null");
@@ -394,4 +394,9 @@ public class OpenAiChatClient extends
&& choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
}
/**
 * Exposes the default options of this chat model.
 * <p>
 * The value handed out is whatever
 * {@link OpenAiChatOptions#fromOptions(OpenAiChatOptions)} produces for the internal
 * {@code defaultOptions} field, not the field itself.
 * @return the options produced by {@code OpenAiChatOptions.fromOptions(defaultOptions)}
 */
@Override
public ChatOptions getDefaultOptions() {
    OpenAiChatOptions defaultsCopy = OpenAiChatOptions.fromOptions(this.defaultOptions);
    return defaultsCopy;
}
}

View File

@@ -134,10 +134,10 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
private @JsonProperty("user") String user;
/**
* OpenAI Tool Function Callbacks to register with the ChatClient.
* OpenAI Tool Function Callbacks to register with the ChatModel.
* For Prompt Options the functionCallbacks are automatically enabled for the duration of the prompt execution.
* For Default Options the functionCallbacks are registered but disabled by default. Use the enableFunctions to set the functions
* from the registry to be used by the ChatClient chat completion requests.
* from the registry to be used by the ChatModel chat completion requests.
*/
@NestedConfigurationProperty
@JsonIgnore
@@ -567,4 +567,27 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
}
/**
 * Creates a new {@link OpenAiChatOptions} instance populated from the given options.
 * <p>
 * The collection-valued settings (logit bias, stop sequences, tools, function
 * callbacks and function names) are copied into fresh collections, so adding to or
 * removing from a collection of the returned instance does not affect
 * {@code fromOptions}, and vice versa. A {@code null} collection is passed through
 * as {@code null}. The elements inside the collections are shared, not cloned.
 * @param fromOptions the options to copy from
 * @return a new {@link OpenAiChatOptions} instance with the same settings
 */
public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
	// Fully qualified java.util names keep this method independent of the file's import list.
	return OpenAiChatOptions.builder()
		.withModel(fromOptions.getModel())
		.withFrequencyPenalty(fromOptions.getFrequencyPenalty())
		.withLogitBias(fromOptions.getLogitBias() != null
				? new java.util.LinkedHashMap<>(fromOptions.getLogitBias()) : null)
		.withLogprobs(fromOptions.getLogprobs())
		.withTopLogprobs(fromOptions.getTopLogprobs())
		.withMaxTokens(fromOptions.getMaxTokens())
		.withN(fromOptions.getN())
		.withPresencePenalty(fromOptions.getPresencePenalty())
		.withResponseFormat(fromOptions.getResponseFormat())
		.withSeed(fromOptions.getSeed())
		.withStop(fromOptions.getStop() != null ? new java.util.ArrayList<>(fromOptions.getStop()) : null)
		.withTemperature(fromOptions.getTemperature())
		.withTopP(fromOptions.getTopP())
		.withTools(fromOptions.getTools() != null ? new java.util.ArrayList<>(fromOptions.getTools()) : null)
		.withToolChoice(fromOptions.getToolChoice())
		.withUser(fromOptions.getUser())
		.withFunctionCallbacks(fromOptions.getFunctionCallbacks() != null
				? new java.util.ArrayList<>(fromOptions.getFunctionCallbacks()) : null)
		.withFunctions(fromOptions.getFunctions() != null
				? new java.util.LinkedHashSet<>(fromOptions.getFunctions()) : null)
		.build();
}
}

View File

@@ -22,7 +22,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
@@ -41,9 +41,9 @@ import org.springframework.util.Assert;
*
* @author Christian Tzolov
*/
public class OpenAiEmbeddingClient extends AbstractEmbeddingClient {
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingClient.class);
private static final Logger logger = LoggerFactory.getLogger(OpenAiEmbeddingModel.class);
private final OpenAiEmbeddingOptions defaultOptions;
@@ -54,43 +54,43 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient {
private final MetadataMode metadataMode;
/**
* Constructor for the OpenAiEmbeddingClient class.
* Constructor for the OpenAiEmbeddingModel class.
* @param openAiApi The OpenAiApi instance to use for making API requests.
*/
public OpenAiEmbeddingClient(OpenAiApi openAiApi) {
public OpenAiEmbeddingModel(OpenAiApi openAiApi) {
this(openAiApi, MetadataMode.EMBED);
}
/**
* Initializes a new instance of the OpenAiEmbeddingClient class.
* Initializes a new instance of the OpenAiEmbeddingModel class.
* @param openAiApi The OpenAiApi instance to use for making API requests.
* @param metadataMode The mode for generating metadata.
*/
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) {
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode) {
this(openAiApi, metadataMode,
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the OpenAiEmbeddingClient class.
* Initializes a new instance of the OpenAiEmbeddingModel class.
* @param openAiApi The OpenAiApi instance to use for making API requests.
* @param metadataMode The mode for generating metadata.
* @param openAiEmbeddingOptions The options for OpenAi embedding.
*/
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode,
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode,
OpenAiEmbeddingOptions openAiEmbeddingOptions) {
this(openAiApi, metadataMode, openAiEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
/**
* Initializes a new instance of the OpenAiEmbeddingClient class.
* Initializes a new instance of the OpenAiEmbeddingModel class.
* @param openAiApi - The OpenAiApi instance to use for making API requests.
* @param metadataMode - The mode for generating metadata.
* @param options - The options for OpenAI embedding.
* @param retryTemplate - The RetryTemplate for retrying failed API requests.
*/
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
RetryTemplate retryTemplate) {
Assert.notNull(openAiApi, "OpenAiService must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");

View File

@@ -21,7 +21,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.Image;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
@@ -37,16 +37,16 @@ import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
/**
* OpenAiImageClient is a class that implements the ImageClient interface. It provides a
* OpenAiImageModel is a class that implements the ImageModel interface. It provides a
* client for calling the OpenAI image generation API.
*
* @author Mark Pollack
* @author Christian Tzolov
* @since 0.8.0
*/
public class OpenAiImageClient implements ImageClient {
public class OpenAiImageModel implements ImageModel {
private final static Logger logger = LoggerFactory.getLogger(OpenAiImageClient.class);
private final static Logger logger = LoggerFactory.getLogger(OpenAiImageModel.class);
private OpenAiImageOptions defaultOptions;
@@ -54,11 +54,11 @@ public class OpenAiImageClient implements ImageClient {
public final RetryTemplate retryTemplate;
public OpenAiImageClient(OpenAiImageApi openAiImageApi) {
public OpenAiImageModel(OpenAiImageApi openAiImageApi) {
this(openAiImageApi, OpenAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public OpenAiImageClient(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions,
public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions,
RetryTemplate retryTemplate) {
Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");

View File

@@ -26,6 +26,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.ai.model.ModelDescription;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
@@ -113,7 +114,7 @@ public class OpenAiApi {
* - <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a>
* - <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
*/
public enum ChatModel {
public enum ChatModel implements ModelDescription {
/**
* Multimodal flagship model that's cheaper and faster than GPT-4 Turbo.
* Currently points to gpt-4o-2024-05-13.
@@ -199,6 +200,11 @@ public class OpenAiApi {
/**
 * Returns the string held in this constant's {@code value} field.
 * @return the value of this enum constant
 */
public String getValue() {
	return this.value;
}
/**
 * Returns the model name for this constant. It is the same string that
 * {@link #getValue()} returns.
 * @return the model name
 */
@Override
public String getModelName() {
	return getValue();
}
}
/**

View File

@@ -16,10 +16,10 @@
package org.springframework.ai.openai.audio.speech;
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.model.Model;
/**
* The {@link SpeechClient} interface provides a way to interact with the OpenAI
* The {@link SpeechModel} interface provides a way to interact with the OpenAI
* Text-to-Speech (TTS) API. It allows you to convert text input into lifelike spoken
* audio.
*
@@ -27,7 +27,7 @@ import org.springframework.ai.model.ModelClient;
* @since 1.0.0-M1
*/
@FunctionalInterface
public interface SpeechClient extends ModelClient<SpeechPrompt, SpeechResponse> {
public interface SpeechModel extends Model<SpeechPrompt, SpeechResponse> {
/**
* Generates spoken audio from the provided text message.

View File

@@ -16,11 +16,11 @@
package org.springframework.ai.openai.audio.speech;
import org.springframework.ai.model.StreamingModelClient;
import org.springframework.ai.model.StreamingModel;
import reactor.core.publisher.Flux;
/**
* The {@link StreamingSpeechClient} interface provides a way to interact with the OpenAI
* The {@link StreamingSpeechModel} interface provides a way to interact with the OpenAI
* Text-to-Speech (TTS) API using a streaming approach, allowing you to receive the
* generated audio in a real-time fashion.
*
@@ -28,7 +28,7 @@ import reactor.core.publisher.Flux;
* @since 1.0.0-M1
*/
@FunctionalInterface
public interface StreamingSpeechClient extends StreamingModelClient<SpeechPrompt, SpeechResponse> {
public interface StreamingSpeechModel extends StreamingModel<SpeechPrompt, SpeechResponse> {
/**
* Generates a stream of audio bytes from the provided text message.

View File

@@ -34,7 +34,7 @@ public class ChatCompletionRequestTests {
@Test
public void createRequestWithChatOptions() {
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").withTemperature(66.6f).build());
var request = client.createRequest(new Prompt("Test message content"), false);
@@ -60,7 +60,7 @@ public class ChatCompletionRequestTests {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder().withModel("DEFAULT_MODEL").build());
var request = client.createRequest(new Prompt("Test message content",
@@ -90,7 +90,7 @@ public class ChatCompletionRequestTests {
final String TOOL_FUNCTION_NAME = "CurrentWeather";
var client = new OpenAiChatClient(new OpenAiApi("TEST"),
var client = new OpenAiChatModel(new OpenAiApi("TEST"),
OpenAiChatOptions.builder()
.withModel("DEFAULT_MODEL")
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())

View File

@@ -15,7 +15,7 @@
*/
package org.springframework.ai.openai;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
@@ -51,33 +51,33 @@ public class OpenAiTestConfiguration {
}
@Bean
public OpenAiChatClient openAiChatClient(OpenAiApi api) {
OpenAiChatClient openAiChatClient = new OpenAiChatClient(api);
return openAiChatClient;
public OpenAiChatModel openAiChatModel(OpenAiApi api) {
	// Built with the single-argument constructor, i.e. with that constructor's defaults.
	return new OpenAiChatModel(api);
}
@Bean
public OpenAiAudioTranscriptionClient openAiTranscriptionClient(OpenAiAudioApi api) {
OpenAiAudioTranscriptionClient openAiTranscriptionClient = new OpenAiAudioTranscriptionClient(api);
return openAiTranscriptionClient;
public OpenAiAudioTranscriptionModel openAiTranscriptionModel(OpenAiAudioApi api) {
	// Built with the single-argument constructor, i.e. with that constructor's defaults.
	return new OpenAiAudioTranscriptionModel(api);
}
@Bean
public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi api) {
OpenAiAudioSpeechClient openAiAudioSpeechClient = new OpenAiAudioSpeechClient(api);
return openAiAudioSpeechClient;
public OpenAiAudioSpeechModel openAiAudioSpeechModel(OpenAiAudioApi api) {
	// Built with the single-argument constructor, i.e. with that constructor's defaults.
	return new OpenAiAudioSpeechModel(api);
}
@Bean
public OpenAiImageClient openAiImageClient(OpenAiImageApi imageApi) {
OpenAiImageClient openAiImageClient = new OpenAiImageClient(imageApi);
// openAiImageClient.setModel("foobar");
return openAiImageClient;
public OpenAiImageModel openAiImageModel(OpenAiImageApi imageApi) {
	// Built with the single-argument constructor, i.e. with that constructor's defaults.
	return new OpenAiImageModel(imageApi);
}
@Bean
public EmbeddingClient openAiEmbeddingClient(OpenAiApi api) {
return new OpenAiEmbeddingClient(api);
public EmbeddingModel openAiEmbeddingModel(OpenAiApi api) {
// The bean is declared with the EmbeddingModel interface type; the instance itself is an
// OpenAiEmbeddingModel built with its single-argument constructor.
return new OpenAiEmbeddingModel(api);
}
}

View File

@@ -34,7 +34,7 @@ public class TranscriptionRequestTests {
@Test
public void defaultOptions() {
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
var client = new OpenAiAudioTranscriptionModel(new OpenAiAudioApi("TEST"),
OpenAiAudioTranscriptionOptions.builder()
.withModel("DEFAULT_MODEL")
.withResponseFormat(TranscriptResponseFormat.TEXT)
@@ -58,7 +58,7 @@ public class TranscriptionRequestTests {
@Test
public void runtimeOptions() {
var client = new OpenAiAudioTranscriptionClient(new OpenAiAudioApi("TEST"),
var client = new OpenAiAudioTranscriptionModel(new OpenAiAudioApi("TEST"),
OpenAiAudioTranscriptionOptions.builder()
.withModel("DEFAULT_MODEL")
.withResponseFormat(TranscriptResponseFormat.TEXT)

View File

@@ -26,9 +26,9 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiEmbeddingClient;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
@@ -58,16 +58,16 @@ public class AcmeIT extends AbstractIT {
private Resource systemBikePrompt;
@Autowired
private OpenAiEmbeddingClient embeddingClient;
private OpenAiEmbeddingModel embeddingModel;
@Autowired
private OpenAiChatClient chatClient;
private OpenAiChatModel chatModel;
@Test
void beanTest() {
assertThat(bikesResource).isNotNull();
assertThat(embeddingClient).isNotNull();
assertThat(chatClient).isNotNull();
assertThat(embeddingModel).isNotNull();
assertThat(chatModel).isNotNull();
}
// @Test
@@ -81,7 +81,7 @@ public class AcmeIT extends AbstractIT {
// Step 2 - Create embeddings and save to vector store
logger.info("Creating Embeddings...");
VectorStore vectorStore = new SimpleVectorStore(embeddingClient);
VectorStore vectorStore = new SimpleVectorStore(embeddingModel);
vectorStore.accept(textSplitter.apply(jsonReader.get()));
@@ -108,7 +108,7 @@ public class AcmeIT extends AbstractIT {
logger.info("Asking AI generative to reply to question.");
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
logger.info("AI responded.");
ChatResponse response = chatClient.call(prompt);
ChatResponse response = chatModel.call(prompt);
evaluateQuestionAndAnswer(userQuery, response, true);
}

View File

@@ -32,13 +32,13 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiSpeechClientIT extends AbstractIT {
class OpenAiSpeechModelIT extends AbstractIT {
private static final Float SPEED = 1.0f;
@Test
void shouldSuccessfullyStreamAudioBytesForEmptyMessage() {
Flux<byte[]> response = speechClient.stream("Today is a wonderful day to build something people love!");
Flux<byte[]> response = speechModel.stream("Today is a wonderful day to build something people love!");
assertThat(response).isNotNull();
assertThat(response.collectList().block()).isNotNull();
System.out.println(response.collectList().block());
@@ -46,7 +46,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
@Test
void shouldProduceAudioBytesDirectlyFromMessage() {
byte[] audioBytes = speechClient.call("Today is a wonderful day to build something people love!");
byte[] audioBytes = speechModel.call("Today is a wonderful day to build something people love!");
assertThat(audioBytes).hasSizeGreaterThan(0);
}
@@ -61,7 +61,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
.build();
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
SpeechResponse response = speechClient.call(speechPrompt);
SpeechResponse response = speechModel.call(speechPrompt);
byte[] audioBytes = response.getResult().getOutput();
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput()).isNotEmpty();
@@ -79,7 +79,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
.build();
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
SpeechResponse response = speechClient.call(speechPrompt);
SpeechResponse response = speechModel.call(speechPrompt);
OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata();
assertThat(metadata).isNotNull();
assertThat(metadata.getRateLimit()).isNotNull();
@@ -100,7 +100,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
speechOptions);
Flux<SpeechResponse> responseFlux = speechClient.stream(speechPrompt);
Flux<SpeechResponse> responseFlux = speechModel.stream(speechPrompt);
assertThat(responseFlux).isNotNull();
List<SpeechResponse> responses = responseFlux.collectList().block();
assertThat(responses).isNotNull();

View File

@@ -18,7 +18,7 @@ package org.springframework.ai.openai.audio.speech;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.openai.OpenAiAudioSpeechClient;
import org.springframework.ai.openai.OpenAiAudioSpeechModel;
import org.springframework.ai.openai.OpenAiAudioSpeechOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
@@ -45,15 +45,15 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
/**
* @author Ahmed Yousri
*/
@RestClientTest(OpenAiSpeechClientWithSpeechResponseMetadataTests.Config.class)
public class OpenAiSpeechClientWithSpeechResponseMetadataTests {
@RestClientTest(OpenAiSpeechModelWithSpeechResponseMetadataTests.Config.class)
public class OpenAiSpeechModelWithSpeechResponseMetadataTests {
private static String TEST_API_KEY = "sk-1234567890";
private static final Float SPEED = 1.0f;
@Autowired
private OpenAiAudioSpeechClient openAiSpeechClient;
private OpenAiAudioSpeechModel openAiSpeechClient;
@Autowired
private MockRestServiceServer server;
@@ -121,8 +121,8 @@ public class OpenAiSpeechClientWithSpeechResponseMetadataTests {
static class Config {
@Bean
public OpenAiAudioSpeechClient openAiAudioSpeechClient(OpenAiAudioApi openAiAudioApi) {
return new OpenAiAudioSpeechClient(openAiAudioApi);
public OpenAiAudioSpeechModel openAiAudioSpeechClient(OpenAiAudioApi openAiAudioApi) {
return new OpenAiAudioSpeechModel(openAiAudioApi);
}
@Bean

View File

@@ -31,7 +31,7 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiTranscriptionClientIT extends AbstractIT {
class OpenAiTranscriptionModelIT extends AbstractIT {
@Value("classpath:/speech/jfk.flac")
private Resource audioFile;
@@ -43,7 +43,7 @@ class OpenAiTranscriptionClientIT extends AbstractIT {
.withTemperature(0f)
.build();
AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
}
@@ -59,7 +59,7 @@ class OpenAiTranscriptionClientIT extends AbstractIT {
.withResponseFormat(responseFormat)
.build();
AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
AudioTranscriptionResponse response = transcriptionModel.call(transcriptionRequest);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
}

View File

@@ -21,7 +21,7 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionMetadata;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
@@ -48,13 +48,13 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
/**
* @author Michael Lavelle
*/
@RestClientTest(OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.Config.class)
public class OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests {
@RestClientTest(OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests.Config.class)
public class OpenAiTranscriptionModelWithTranscriptionResponseMetadataTests {
private static String TEST_API_KEY = "sk-1234567890";
@Autowired
private OpenAiAudioTranscriptionClient openAiTranscriptionClient;
private OpenAiAudioTranscriptionModel openAiTranscriptionClient;
@Autowired
private MockRestServiceServer server;
@@ -156,8 +156,8 @@ public class OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests {
}
@Bean
public OpenAiAudioTranscriptionClient openAiClient(OpenAiAudioApi openAiAudioApi) {
return new OpenAiAudioTranscriptionClient(openAiAudioApi);
public OpenAiAudioTranscriptionModel openAiClient(OpenAiAudioApi openAiAudioApi) {
return new OpenAiAudioTranscriptionModel(openAiAudioApi);
}
}

View File

@@ -18,7 +18,7 @@ package org.springframework.ai.openai.audio.transcription;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
import org.springframework.ai.openai.OpenAiAudioTranscriptionModel;
import org.springframework.core.io.Resource;
import static org.assertj.core.api.Assertions.assertThat;
@@ -33,18 +33,18 @@ import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit Tests for {@link TranscriptionClient}.
* Unit Tests for {@link TranscriptionModel}.
*
* @author Michael Lavelle
*/
class TranscriptionClientTests {
class TranscriptionModelTests {
@Test
void transcrbeRequestReturnsResponseCorrectly() {
Resource mockAudioFile = Mockito.mock(Resource.class);
OpenAiAudioTranscriptionClient mockClient = Mockito.mock(OpenAiAudioTranscriptionClient.class);
OpenAiAudioTranscriptionModel mockClient = Mockito.mock(OpenAiAudioTranscriptionModel.class);
String mockTranscription = "All your bases are belong to us";

View File

@@ -17,8 +17,8 @@ package org.springframework.ai.openai.chat;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -31,19 +31,10 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatModel;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.OpenAiApi;
@@ -51,7 +42,7 @@ import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;
@@ -65,75 +56,96 @@ class OpenAiChatClientIT extends AbstractIT {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientIT.class);
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
private Resource systemTextResource;
@Test
void roleTest() {
UserMessage userMessage = new UserMessage(
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.call(prompt);
// @formatter:off
ChatResponse response = ChatClient.builder(chatModel).build().prompt()
.system(s -> s.text(systemTextResource)
.param("name", "Bob")
.param("voice", "pirate"))
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
.call()
.chatResponse();
// @formatter:on
logger.info("" + response);
assertThat(response.getResults()).hasSize(1);
assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
// needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
}
@Test
void listOutputConverter() {
DefaultConversionService conversionService = new DefaultConversionService();
ListOutputConverter outputConverter = new ListOutputConverter(conversionService);
// @formatter:off
Collection<String> collection = ChatClient.builder(chatModel).build().prompt()
.user(u -> u.text("List five {subject}")
.param("subject", "ice cream flavors"))
.call()
.list(String.class);
// @formatter:on
String format = outputConverter.getFormat();
String template = """
List five {subject}
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "ice cream flavors", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = this.chatClient.call(prompt).getResult();
assertThat(collection).hasSize(5);
}
List<String> list = outputConverter.convert(generation.getOutput().getContent());
assertThat(list).hasSize(5);
@Test
void listOutputConverter2() {
	// Target type for the structured output: a list of ActorsFilmsRecord entries.
	ParameterizedTypeReference<List<ActorsFilmsRecord>> recordListType = new ParameterizedTypeReference<List<ActorsFilmsRecord>>() {
	};
	ChatClient client = ChatClient.builder(chatModel).build();

	// @formatter:off
	List<ActorsFilmsRecord> actorsFilms = client.prompt()
			.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
			.call()
			.single(recordListType);
	// @formatter:on

	logger.info("" + actorsFilms);
	// The prompt names two actors; the test expects exactly two entries back.
	assertThat(actorsFilms).hasSize(2);
}
@Test
void listOutputConverter3() {
	ChatClient client = ChatClient.builder(chatModel).build();

	// @formatter:off
	Collection<ActorsFilmsRecord> actorsFilms = client.prompt()
			.user("Generate the filmography of 5 movies for Tom Hanks and Bill Murray.")
			.call()
			.list(ActorsFilmsRecord.class);
	// @formatter:on

	logger.info("" + actorsFilms);
	// The prompt names two actors; the test expects exactly two entries back.
	assertThat(actorsFilms).hasSize(2);
}
@Test
void mapOutputConverter() {
MapOutputConverter outputConverter = new MapOutputConverter();
// @formatter:off
Map<String, Object> result = ChatClient.builder(chatModel).build().prompt()
.user(u -> u.text("Provide me a List of {subject}")
.param("subject", "an array of numbers from 1 to 9 under they key name 'numbers'"))
.call()
.single(new ParameterizedTypeReference<Map<String, Object>>() {
});
// @formatter:on
String format = outputConverter.getFormat();
String template = """
Provide me a List of {subject}
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template,
Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Map<String, Object> result = outputConverter.convert(generation.getOutput().getContent());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
}
@Test
void beanOutputConverter() {
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
// @formatter:off
ActorsFilms actorsFilms = ChatClient.builder(chatModel).build().prompt()
.user("Generate the filmography for a random actor.")
.call()
.single(ActorsFilms.class);
// @formatter:on
String format = outputConverter.getFormat();
String template = """
Generate the filmography for a random actor.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
logger.info("" + actorsFilms);
assertThat(actorsFilms.getActor()).isNotBlank();
}
record ActorsFilmsRecord(String actor, List<String> movies) {
@@ -142,18 +154,13 @@ class OpenAiChatClientIT extends AbstractIT {
@Test
void beanOutputConverterRecords() {
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
// @formatter:off
ActorsFilmsRecord actorsFilms = ChatClient.builder(chatModel).build().prompt()
.user("Generate the filmography of 5 movies for Tom Hanks.")
.call()
.single(ActorsFilmsRecord.class);
// @formatter:on
String format = outputConverter.getFormat();
String template = """
Generate the filmography of 5 movies for Tom Hanks.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent());
logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
@@ -164,25 +171,25 @@ class OpenAiChatClientIT extends AbstractIT {
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
String format = outputConverter.getFormat();
String template = """
Generate the filmography of 5 movies for Tom Hanks.
{format}
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
// @formatter:off
Flux<String> chatResponse = ChatClient.builder(chatModel)
.build()
.prompt()
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.stream()
.content();
String generationTextFromStream = streamingChatClient.stream(prompt)
.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());
String generationTextFromStream = chatResponse.collectList()
.block()
.stream()
.collect(Collectors.joining());
// @formatter:on
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
@@ -191,54 +198,33 @@ class OpenAiChatClientIT extends AbstractIT {
@Test
void functionCallTest() {
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = OpenAiChatOptions.builder()
.withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the weather in location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
// @formatter:off
String response = ChatClient.builder(chatModel).build().prompt()
.user(u -> u.text("What's the weather like in San Francisco, Tokyo, and Paris?"))
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.call()
.content();
// @formatter:on
logger.info("Response: {}", response);
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("30.0", "30");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("10.0", "10");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("15.0", "15");
assertThat(response).containsAnyOf("30.0", "30");
assertThat(response).containsAnyOf("10.0", "10");
assertThat(response).containsAnyOf("15.0", "15");
}
@Test
void streamFunctionCallTest() {
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
// @formatter:off
Flux<String> response = ChatClient.builder(chatModel).build().prompt()
.user("What's the weather like in San Francisco, Tokyo, and Paris?")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.stream()
.content();
// @formatter:on
List<Message> messages = new ArrayList<>(List.of(userMessage));
var promptOptions = OpenAiChatOptions.builder()
// .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
.withName("getCurrentWeather")
.withDescription("Get the weather in location")
.withResponseConverter((response) -> "" + response.temp() + response.unit())
.build()))
.build();
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(messages, promptOptions));
String content = response.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());
String content = response.collectList().block().stream().collect(Collectors.joining());
logger.info("Response: {}", content);
assertThat(content).containsAnyOf("30.0", "30");
@@ -250,53 +236,62 @@ class OpenAiChatClientIT extends AbstractIT {
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
void multiModalityEmbeddedImage(String modelName) throws IOException {
var imageData = new ClassPathResource("/test.png");
// @formatter:off
String response = ChatClient.builder(chatModel).build().prompt()
// TODO consider adding model(...) method to ChatClient as a shortcut to
// OpenAiChatOptions.builder().withModel(modelName).build()
.options(OpenAiChatOptions.builder().withModel(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?")
.media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png")))
.call()
.content();
// @formatter:on
var userMessage = new UserMessage("Explain what do you see on this picture?",
List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
var response = chatClient
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
logger.info(response);
assertThat(response).contains("bananas", "apple");
assertThat(response).containsAnyOf("bowl", "basket");
}
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
void multiModalityImageUrl(String modelName) throws IOException {
var userMessage = new UserMessage("Explain what do you see on this picture?", List
.of(new Media(MimeTypeUtils.IMAGE_PNG,
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
// TODO: add url method that wraps the checked exception.
URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
ChatResponse response = chatClient
.call(new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withModel(modelName).build()));
// @formatter:off
String response = ChatClient.builder(chatModel).build().prompt()
// TODO consider adding model(...) method to ChatClient as a shortcut to
// OpenAiChatOptions.builder().withModel(modelName).build()
.options(OpenAiChatOptions.builder().withModel(modelName).build())
.user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url))
.call()
.content();
// @formatter:on
logger.info(response.getResult().getOutput().getContent());
assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple");
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("bowl", "basket");
logger.info(response);
assertThat(response).contains("bananas", "apple");
assertThat(response).containsAnyOf("bowl", "basket");
}
@Test
void streamingMultiModalityImageUrl() throws IOException {
var userMessage = new UserMessage("Explain what do you see on this picture?", List
.of(new Media(MimeTypeUtils.IMAGE_PNG,
new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))));
// TODO: add url method that wraps the checked exception.
URL url = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(List.of(userMessage),
OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
// @formatter:off
Flux<String> response = ChatClient.builder(chatModel).build().prompt()
.options(OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue())
.build())
.user(u -> u.text("Explain what do you see on this picture?")
.media(MimeTypeUtils.IMAGE_PNG, url))
.stream()
.content();
// @formatter:on
String content = response.collectList().block().stream().collect(Collectors.joining());
String content = response.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.collect(Collectors.joining());
logger.info("Response: {}", content);
assertThat(content).contains("bananas", "apple");
assertThat(content).containsAnyOf("bowl", "basket");

View File

@@ -27,7 +27,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
@@ -41,14 +41,14 @@ import static org.assertj.core.api.Assertions.assertThat;
/**
* @author Christian Tzolov
*/
@SpringBootTest(classes = OpenAiChatClient2IT.Config.class)
@SpringBootTest(classes = OpenAiChatModel2IT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiChatClient2IT {
public class OpenAiChatModel2IT {
private final Logger logger = LoggerFactory.getLogger(getClass());
@Autowired
private OpenAiChatClient openAiChatClient;
private OpenAiChatModel openAiChatModel;
@Test
void responseFormatTest() throws JsonMappingException, JsonProcessingException {
@@ -67,7 +67,7 @@ public class OpenAiChatClient2IT {
.withResponseFormat(new ChatCompletionRequest.ResponseFormat("json_object"))
.build());
ChatResponse response = this.openAiChatClient.call(prompt);
ChatResponse response = this.openAiChatModel.call(prompt);
assertThat(response).isNotNull();
@@ -99,8 +99,8 @@ public class OpenAiChatClient2IT {
}
@Bean
public OpenAiChatClient openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatClient(openAiApi);
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
return new OpenAiChatModel(openAiApi);
}
}

View File

@@ -0,0 +1,305 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.openai.chat;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiChatModelIT extends AbstractIT {
private static final Logger logger = LoggerFactory.getLogger(OpenAiChatModelIT.class);
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
/**
 * Sends a user question together with a system message rendered from
 * {@code systemResource} and checks that exactly one generation is returned and that
 * its content mentions "Blackbeard".
 */
@Test
void roleTest() {
    UserMessage question = new UserMessage(
            "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
    Message systemMessage = new SystemPromptTemplate(systemResource)
        .createMessage(Map.of("name", "Bob", "voice", "pirate"));

    // Message order (user first, then system) is kept exactly as the prompt was built before.
    ChatResponse response = chatModel.call(new Prompt(List.of(question, systemMessage)));

    List<Generation> generations = response.getResults();
    assertThat(generations).hasSize(1);
    assertThat(generations.get(0).getOutput().getContent()).contains("Blackbeard");
    // needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
}
/**
 * Asks the model for five ice cream flavors, appending the format instructions of a
 * {@link ListOutputConverter}, and verifies that the converted list has five entries.
 */
@Test
void listOutputConverter() {
    ListOutputConverter converter = new ListOutputConverter(new DefaultConversionService());

    String template = """
            List five {subject}
            {format}
            """;
    Map<String, Object> model = Map.of("subject", "ice cream flavors", "format", converter.getFormat());
    Prompt prompt = new Prompt(new PromptTemplate(template, model).createMessage());

    Generation generation = this.chatModel.call(prompt).getResult();
    String content = generation.getOutput().getContent();
    List<String> flavors = converter.convert(content);

    assertThat(flavors).hasSize(5);
}
/**
 * Asks the model for the numbers 1 to 9 under the key {@code numbers}, appending the
 * format instructions of a {@link MapOutputConverter}, and verifies the converted map.
 */
@Test
void mapOutputConverter() {
    MapOutputConverter converter = new MapOutputConverter();

    String template = """
            Provide me a List of {subject}
            {format}
            """;
    Map<String, Object> model = Map.of(
            "subject", "an array of numbers from 1 to 9 under they key name 'numbers'",
            "format", converter.getFormat());
    Prompt prompt = new Prompt(new PromptTemplate(template, model).createMessage());

    String content = chatModel.call(prompt).getResult().getOutput().getContent();
    Map<String, Object> parsed = converter.convert(content);

    Object numbers = parsed.get("numbers");
    assertThat(numbers).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
}
/**
 * Requests the filmography of a random actor, appending the format instructions of a
 * {@link BeanOutputConverter}, and verifies that the response converts into an
 * {@code ActorsFilms} bean whose actor name is not blank.
 */
@Test
void beanOutputConverter() {
    BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);

    String format = outputConverter.getFormat();
    String template = """
            Generate the filmography for a random actor.
            {format}
            """;
    PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
    Prompt prompt = new Prompt(promptTemplate.createMessage());
    Generation generation = chatModel.call(prompt).getResult();

    ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent());
    logger.info("{}", actorsFilms);
    // Assert on the converted bean so the test fails when the conversion yields no actor,
    // instead of passing vacuously with an unused result.
    assertThat(actorsFilms.getActor()).isNotBlank();
}
/**
 * Conversion target for the record-based {@link BeanOutputConverter} tests in this
 * class: an actor and a list of that actor's movies.
 */
record ActorsFilmsRecord(String actor, List<String> movies) {
}
/**
 * Requests five Tom Hanks movies, appending the format instructions of a
 * {@link BeanOutputConverter}, and verifies that the response converts into an
 * {@link ActorsFilmsRecord} with the expected actor and five movies.
 */
@Test
void beanOutputConverterRecords() {
    BeanOutputConverter<ActorsFilmsRecord> converter = new BeanOutputConverter<>(ActorsFilmsRecord.class);

    String template = """
            Generate the filmography of 5 movies for Tom Hanks.
            {format}
            """;
    Prompt prompt = new Prompt(
            new PromptTemplate(template, Map.of("format", converter.getFormat())).createMessage());

    Generation generation = chatModel.call(prompt).getResult();
    String content = generation.getOutput().getContent();
    ActorsFilmsRecord filmography = converter.convert(content);

    logger.info("" + filmography);
    assertThat(filmography.actor()).isEqualTo("Tom Hanks");
    assertThat(filmography.movies()).hasSize(5);
}
/**
 * Streaming variant of the record conversion test: the content of every streamed
 * generation is concatenated and the joined text is converted into an
 * {@link ActorsFilmsRecord}.
 */
@Test
void beanStreamOutputConverterRecords() {
    BeanOutputConverter<ActorsFilmsRecord> converter = new BeanOutputConverter<>(ActorsFilmsRecord.class);

    String template = """
            Generate the filmography of 5 movies for Tom Hanks.
            {format}
            """;
    Prompt prompt = new Prompt(
            new PromptTemplate(template, Map.of("format", converter.getFormat())).createMessage());

    // Block until the stream completes, then stitch the partial contents back together.
    List<ChatResponse> chunks = streamingChatModel.stream(prompt).collectList().block();
    String streamedText = chunks.stream()
        .flatMap(chunk -> chunk.getResults().stream())
        .map(generation -> generation.getOutput().getContent())
        .collect(Collectors.joining());

    ActorsFilmsRecord filmography = converter.convert(streamedText);

    logger.info("" + filmography);
    assertThat(filmography.actor()).isEqualTo("Tom Hanks");
    assertThat(filmography.movies()).hasSize(5);
}
/**
 * Registers {@link MockWeatherService} as the {@code getCurrentWeather} function and
 * verifies that the answer about San Francisco, Tokyo and Paris contains the values
 * 30, 10 and 15.
 */
@Test
void functionCallTest() {
    var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService())
        .withName("getCurrentWeather")
        .withDescription("Get the weather in location")
        .withResponseConverter((weather) -> "" + weather.temp() + weather.unit())
        .build();

    var options = OpenAiChatOptions.builder()
        .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
        .withFunctionCallbacks(List.of(weatherFunction))
        .build();

    List<Message> conversation = new ArrayList<>();
    conversation.add(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));

    ChatResponse chatResponse = chatModel.call(new Prompt(conversation, options));
    logger.info("Response: {}", chatResponse);

    String answer = chatResponse.getResult().getOutput().getContent();
    assertThat(answer).containsAnyOf("30.0", "30");
    assertThat(answer).containsAnyOf("10.0", "10");
    assertThat(answer).containsAnyOf("15.0", "15");
}
/**
 * Streaming variant of the function call test: registers {@link MockWeatherService} as
 * the {@code getCurrentWeather} function, joins the streamed content and verifies that
 * it contains the values 30, 10 and 15.
 */
@Test
void streamFunctionCallTest() {
    var weatherFunction = FunctionCallbackWrapper.builder(new MockWeatherService())
        .withName("getCurrentWeather")
        .withDescription("Get the weather in location")
        .withResponseConverter((weather) -> "" + weather.temp() + weather.unit())
        .build();

    // No explicit model is set for this test; the withModel(...) call is commented out:
    // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue())
    var options = OpenAiChatOptions.builder()
        .withFunctionCallbacks(List.of(weatherFunction))
        .build();

    List<Message> conversation = new ArrayList<>();
    conversation.add(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));

    Flux<ChatResponse> responseFlux = streamingChatModel.stream(new Prompt(conversation, options));

    // Block until the stream completes, then concatenate the content of every generation.
    List<ChatResponse> chunks = responseFlux.collectList().block();
    String content = chunks.stream()
        .map(ChatResponse::getResults)
        .flatMap(List::stream)
        .map(Generation::getOutput)
        .map(AssistantMessage::getContent)
        .collect(Collectors.joining());
    logger.info("Response: {}", content);

    assertThat(content).containsAnyOf("30.0", "30");
    assertThat(content).containsAnyOf("10.0", "10");
    assertThat(content).containsAnyOf("15.0", "15");
}
/**
 * Sends the classpath image {@code /test.png} as PNG media to the given model and
 * verifies that the description mentions bananas, an apple and a bowl or basket.
 * @param modelName the OpenAI model to run the request against
 */
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
void multiModalityEmbeddedImage(String modelName) throws IOException {
    Media image = new Media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"));
    UserMessage userMessage = new UserMessage("Explain what do you see on this picture?", List.of(image));
    OpenAiChatOptions options = OpenAiChatOptions.builder().withModel(modelName).build();

    ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), options));
    String description = response.getResult().getOutput().getContent();

    logger.info(description);
    assertThat(description).contains("bananas", "apple");
    assertThat(description).containsAnyOf("bowl", "basket");
}
/**
 * Sends an image referenced by URL as PNG media to the given model and verifies that
 * the description mentions bananas, an apple and a bowl or basket.
 * @param modelName the OpenAI model to run the request against
 */
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "gpt-4-vision-preview", "gpt-4o" })
void multiModalityImageUrl(String modelName) throws IOException {
    URL imageUrl = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
    Media image = new Media(MimeTypeUtils.IMAGE_PNG, imageUrl);
    UserMessage userMessage = new UserMessage("Explain what do you see on this picture?", List.of(image));
    OpenAiChatOptions options = OpenAiChatOptions.builder().withModel(modelName).build();

    ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), options));
    String description = response.getResult().getOutput().getContent();

    logger.info(description);
    assertThat(description).contains("bananas", "apple");
    assertThat(description).containsAnyOf("bowl", "basket");
}
/**
 * Streaming variant of the image URL test: sends the image as PNG media to the
 * {@code GPT_4_VISION_PREVIEW} model, joins the streamed content and verifies that it
 * mentions bananas, an apple and a bowl or basket.
 */
@Test
void streamingMultiModalityImageUrl() throws IOException {
    URL imageUrl = new URL("https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png");
    Media image = new Media(MimeTypeUtils.IMAGE_PNG, imageUrl);
    UserMessage userMessage = new UserMessage("Explain what do you see on this picture?", List.of(image));
    OpenAiChatOptions options = OpenAiChatOptions.builder()
        .withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue())
        .build();

    Flux<ChatResponse> responseFlux = streamingChatModel.stream(new Prompt(List.of(userMessage), options));

    // Block until the stream completes, then concatenate the content of every generation.
    List<ChatResponse> chunks = responseFlux.collectList().block();
    String content = chunks.stream()
        .map(ChatResponse::getResults)
        .flatMap(List::stream)
        .map(Generation::getOutput)
        .map(AssistantMessage::getContent)
        .collect(Collectors.joining());
    logger.info("Response: {}", content);

    assertThat(content).contains("bananas", "apple");
    assertThat(content).containsAnyOf("bowl", "basket");
}
}

View File

@@ -39,10 +39,10 @@ import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
class OpenAiChatModelTypeReferenceBeanOutputConverterIT extends AbstractIT {
private static final Logger logger = LoggerFactory
.getLogger(OpenAiChatClientTypeReferenceBeanOutputConverterIT.class);
.getLogger(OpenAiChatModelTypeReferenceBeanOutputConverterIT.class);
record ActorsFilmsRecord(String actor, List<String> movies) {
}
@@ -61,7 +61,7 @@ class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = chatClient.call(prompt).getResult();
Generation generation = chatModel.call(prompt).getResult();
List<ActorsFilmsRecord> actorsFilms = outputConverter.convert(generation.getOutput().getContent());
logger.info("" + actorsFilms);
@@ -87,7 +87,7 @@ class OpenAiChatClientTypeReferenceBeanOutputConverterIT extends AbstractIT {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
String generationTextFromStream = streamingChatClient.stream(prompt)
String generationTextFromStream = streamingChatModel.stream(prompt)
.collectList()
.block()
.stream()

Some files were not shown because too many files have changed in this diff Show More