ChatClient/DefaultChatClient refactoring

Now the ChatClient contains only the Spec interfaces while
   all implementations are moved to the DefaultChatClient
   and the DefaultChatModelBuilder classes.
This commit is contained in:
Christian Tzolov
2024-06-04 14:07:12 +02:00
parent b9844f9bf8
commit 6ad36b7653
4 changed files with 1005 additions and 687 deletions

View File

@@ -15,41 +15,26 @@
*/
package org.springframework.ai.chat.client;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;
/**
* Client to perform stateless requests to an AI Model, using a fluent API.
@@ -60,7 +45,7 @@ import org.springframework.util.StringUtils;
* @author Christian Tzolov
* @author Josh Long
* @author Arjen Poutsma
* @since 1.0.0 M1
* @since 1.0.0
*/
public interface ChatClient {
@@ -69,761 +54,200 @@ public interface ChatClient {
}
static Builder builder(ChatModel chatModel) {
return new Builder(chatModel);
return new DefaultChatClientBuilder(chatModel);
}
ChatClientRequest prompt();
ChatClientRequestSpec prompt();
ChatClientPromptRequest prompt(Prompt prompt);
ChatClientPromptRequestSpec prompt(Prompt prompt);
/**
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
* settings are replicated from the default {@link ChatClientRequest} of this client.
* settings are replicated from the default {@link ChatClientRequestSpec} of this
* client.
*/
Builder mutate();
interface PromptSpec<T> {
interface PromptUserSpec {
T text(String text);
PromptUserSpec text(String text);
T text(Resource text, Charset charset);
PromptUserSpec text(Resource text, Charset charset);
T text(Resource text);
PromptUserSpec text(Resource text);
T params(Map<String, Object> p);
PromptUserSpec params(Map<String, Object> p);
T param(String k, Object v);
PromptUserSpec param(String k, Object v);
PromptUserSpec media(Media... media);
PromptUserSpec media(MimeType mimeType, URL url);
PromptUserSpec media(MimeType mimeType, Resource resource);
List<Media> media();
}
abstract class AbstractPromptSpec<T extends AbstractPromptSpec<T>> implements PromptSpec<T> {
interface PromptSystemSpec {
private String text = "";
PromptSystemSpec text(String text);
private final Map<String, Object> params = new HashMap<>();
PromptSystemSpec text(Resource text, Charset charset);
@Override
public T text(String text) {
this.text = text;
return self();
}
PromptSystemSpec text(Resource text);
@Override
public T text(Resource text, Charset charset) {
try {
this.text(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return self();
}
PromptSystemSpec params(Map<String, Object> p);
@Override
public T text(Resource text) {
this.text(text, Charset.defaultCharset());
return self();
}
@Override
public T param(String k, Object v) {
this.params.put(k, v);
return self();
}
@Override
public T params(Map<String, Object> p) {
this.params.putAll(p);
return self();
}
protected abstract T self();
protected String text() {
return this.text;
}
protected Map<String, Object> params() {
return this.params;
}
PromptSystemSpec param(String k, Object v);
}
class UserSpec extends AbstractPromptSpec<UserSpec> implements PromptSpec<UserSpec> {
interface AdvisorSpec {
private final List<Media> media = new ArrayList<>();
AdvisorSpec param(String k, Object v);
public UserSpec media(Media... media) {
this.media.addAll(Arrays.asList(media));
return self();
}
AdvisorSpec params(Map<String, Object> p);
public UserSpec media(MimeType mimeType, URL url) {
this.media.add(new Media(mimeType, url));
return self();
}
AdvisorSpec advisors(RequestResponseAdvisor... advisors);
public UserSpec media(MimeType mimeType, Resource resource) {
this.media.add(new Media(mimeType, resource));
return self();
}
protected List<Media> media() {
return this.media;
}
@Override
protected UserSpec self() {
return this;
}
AdvisorSpec advisors(List<RequestResponseAdvisor> advisors);
}
class SystemSpec extends AbstractPromptSpec<SystemSpec> implements PromptSpec<SystemSpec> {
interface CallResponseSpec {
@Override
protected SystemSpec self() {
return this;
}
<T> T entity(ParameterizedTypeReference<T> type);
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);
<T> T entity(Class<T> type);
ChatResponse chatResponse();
String content();
}
class ChatClientPromptRequest {
interface StreamResponseSpec {
private final ChatModel chatModel;
Flux<ChatResponse> chatResponse();
private final Prompt prompt;
public ChatClientPromptRequest(ChatModel chatModel, Prompt prompt) {
this.chatModel = chatModel;
this.prompt = prompt;
}
public ChatClientRequest.CallPromptResponseSpec call() {
return new ChatClientRequest.CallPromptResponseSpec(this.chatModel, this.prompt);
}
public ChatClientRequest.StreamPromptResponseSpec stream() {
return new ChatClientRequest.StreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt);
}
Flux<String> content();
}
class AdvisorSpec {
interface ChatClientPromptRequestSpec {
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
CallPromptResponseSpec call();
private final Map<String, Object> params = new HashMap<>();
public AdvisorSpec param(String k, Object v) {
this.params.put(k, v);
return this;
}
public AdvisorSpec params(Map<String, Object> p) {
this.params.putAll(p);
return this;
}
public AdvisorSpec advisors(RequestResponseAdvisor... advisors) {
this.advisors.addAll(List.of(advisors));
return this;
}
public AdvisorSpec advisors(List<RequestResponseAdvisor> advisors) {
this.advisors.addAll(advisors);
return this;
}
StreamPromptResponseSpec stream();
}
class ChatClientRequest {
interface CallPromptResponseSpec {
private final ChatModel chatModel;
String content();
private String userText = "";
List<String> contents();
private String systemText = "";
ChatResponse chatResponse();
private ChatOptions chatOptions;
}
private final List<Media> media = new ArrayList<>();
interface StreamPromptResponseSpec {
private final List<String> functionNames = new ArrayList<>();
Flux<ChatResponse> chatResponse();
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
public Flux<String> content();
private final List<Message> messages = new ArrayList<>();
}
private final Map<String, Object> userParams = new HashMap<>();
private final Map<String, Object> systemParams = new HashMap<>();
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
private final Map<String, Object> advisorParams = new HashMap<>();
/* copy constructor */
ChatClientRequest(ChatClientRequest ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
}
public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
this.userText = userText;
this.userParams.putAll(userParams);
this.systemText = systemText;
this.systemParams.putAll(systemParams);
this.functionNames.addAll(functionNames);
this.functionCallbacks.addAll(functionCallbacks);
this.messages.addAll(messages);
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
}
interface ChatClientRequestSpec {
/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
public Builder mutate() {
Builder builder = ChatClient.builder(chatModel)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
.media(this.media.toArray(new Media[this.media.size()])))
.defaultOptions(this.chatOptions)
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
Builder mutate();
// workaround to set the missing fields.
builder.defaultRequest.messages.addAll(this.messages);
builder.defaultRequest.functionCallbacks.addAll(this.functionCallbacks);
ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer);
return builder;
}
ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors);
public ChatClientRequest advisors(Consumer<AdvisorSpec> consumer) {
Assert.notNull(consumer, "the consumer must be non-null");
var as = new AdvisorSpec();
consumer.accept(as);
this.advisorParams.putAll(as.params);
this.advisors.addAll(as.advisors);
return this;
}
ChatClientRequestSpec advisors(List<RequestResponseAdvisor> advisors);
public ChatClientRequest advisors(RequestResponseAdvisor... advisors) {
Assert.notNull(advisors, "the advisors must be non-null");
this.advisors.addAll(List.of(advisors));
return this;
}
ChatClientRequestSpec messages(Message... messages);
public ChatClientRequest advisors(List<RequestResponseAdvisor> advisors) {
Assert.notNull(advisors, "the advisors must be non-null");
this.advisors.addAll(advisors);
return this;
}
ChatClientRequestSpec messages(List<Message> messages);
public ChatClientRequest messages(Message... messages) {
Assert.notNull(messages, "the messages must be non-null");
this.messages.addAll(List.of(messages));
return this;
}
<T extends ChatOptions> ChatClientRequestSpec options(T options);
public ChatClientRequest messages(List<Message> messages) {
Assert.notNull(messages, "the messages must be non-null");
this.messages.addAll(messages);
return this;
}
<I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function);
public <T extends ChatOptions> ChatClientRequest options(T options) {
Assert.notNull(options, "the options must be non-null");
this.chatOptions = options;
return this;
}
ChatClientRequestSpec functions(String... functionBeanNames);
public <I, O> ChatClientRequest function(String name, String description,
java.util.function.Function<I, O> function) {
ChatClientRequestSpec system(String text);
Assert.hasText(name, "the name must be non-null and non-empty");
Assert.hasText(description, "the description must be non-null and non-empty");
Assert.notNull(function, "the function must be non-null");
ChatClientRequestSpec system(Resource textResource, Charset charset);
var fcw = FunctionCallbackWrapper.builder(function)
.withDescription(description)
.withName(name)
.withResponseConverter(Object::toString)
.build();
this.functionCallbacks.add(fcw);
return this;
}
ChatClientRequestSpec system(Resource text);
public ChatClientRequest functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
this.functionNames.addAll(List.of(functionBeanNames));
return this;
}
ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer);
public ChatClientRequest system(String text) {
Assert.notNull(text, "the text must be non-null");
this.systemText = text;
return this;
}
ChatClientRequestSpec user(String text);
public ChatClientRequest system(Resource textResource, Charset charset) {
ChatClientRequestSpec user(Resource text, Charset charset);
Assert.notNull(textResource, "the text resource must be non-null");
Assert.notNull(charset, "the charset must be non-null");
ChatClientRequestSpec user(Resource text);
try {
this.systemText = textResource.getContentAsString(charset);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
public ChatClientRequest system(Resource text) {
Assert.notNull(text, "the text resource must be non-null");
return this.system(text, Charset.defaultCharset());
}
// ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest,
// Map<String, Object> context);
public ChatClientRequest system(Consumer<SystemSpec> consumer) {
CallResponseSpec call();
Assert.notNull(consumer, "the consumer must be non-null");
var ss = new SystemSpec();
consumer.accept(ss);
this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText;
this.systemParams.putAll(ss.params());
return this;
}
public ChatClientRequest user(String text) {
Assert.notNull(text, "the text must be non-null");
this.userText = text;
return this;
}
public ChatClientRequest user(Resource text, Charset charset) {
Assert.notNull(text, "the text resource must be non-null");
Assert.notNull(charset, "the charset must be non-null");
try {
this.userText = text.getContentAsString(charset);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
public ChatClientRequest user(Resource text) {
Assert.notNull(text, "the text resource must be non-null");
return this.user(text, Charset.defaultCharset());
}
public ChatClientRequest user(Consumer<UserSpec> consumer) {
Assert.notNull(consumer, "the consumer must be non-null");
var us = new UserSpec();
consumer.accept(us);
this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
this.userParams.putAll(us.params());
this.media.addAll(us.media());
return this;
}
public static class StreamPromptResponseSpec {
private final Prompt prompt;
private final StreamingChatModel chatModel;
public StreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) {
this.chatModel = streamingChatModel;
this.prompt = prompt;
}
public Flux<ChatResponse> chatResponse() {
return doGetFluxChatResponse(this.prompt);
}
private Flux<ChatResponse> doGetFluxChatResponse(Prompt prompt) {
return this.chatModel.stream(prompt);
}
public Flux<String> content() {
return doGetFluxChatResponse(this.prompt).map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getContent() == null) {
return "";
}
return r.getResult().getOutput().getContent();
}).filter(v -> StringUtils.hasText(v));
}
}
public static class CallPromptResponseSpec {
private final ChatModel chatModel;
private final Prompt prompt;
public CallPromptResponseSpec(ChatModel chatModel, Prompt prompt) {
this.chatModel = chatModel;
this.prompt = prompt;
}
public String content() {
return doGetChatResponse(this.prompt).getResult().getOutput().getContent();
}
public List<String> contents() {
return doGetChatResponse(this.prompt).getResults()
.stream()
.map(r -> r.getOutput().getContent())
.toList();
}
public ChatResponse chatResponse() {
return doGetChatResponse(this.prompt);
}
private ChatResponse doGetChatResponse(Prompt prompt) {
return chatModel.call(prompt);
}
}
private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest, Map<String, Object> context) {
ChatClientRequest advisedRequest = inputRequest;
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText,
inputRequest.systemText, inputRequest.chatOptions, inputRequest.media,
inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages,
inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors,
inputRequest.advisorParams);
// apply the advisors onRequest
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
for (RequestResponseAdvisor advisor : currentAdvisors) {
adviseRequest = advisor.adviseRequest(adviseRequest, context);
}
advisedRequest = new ChatClientRequest(adviseRequest.chatModel(), adviseRequest.userText(),
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
}
return advisedRequest;
}
public static class CallResponseSpec {
private final ChatClientRequest request;
private final ChatModel chatModel;
public CallResponseSpec(ChatModel chatModel, ChatClientRequest request) {
this.chatModel = chatModel;
this.request = request;
}
public <T> T entity(ParameterizedTypeReference<T> type) {
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
}
public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
return doSingleWithBeanOutputConverter(structuredOutputConverter);
}
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var stringResponse = chatResponse.getResult().getOutput().getContent();
return boc.convert(stringResponse);
}
public <T> T entity(Class<T> type) {
Assert.notNull(type, "the class must be non-null");
var boc = new BeanOutputConverter<T>(type);
return doSingleWithBeanOutputConverter(boc);
}
private ChatResponse doGetChatResponse() {
return this.doGetChatResponse(this.request, "");
}
private ChatResponse doGetChatResponse(ChatClientRequest inputRequest, String formatParam) {
Map<String, Object> context = new ConcurrentHashMap<>();
context.putAll(inputRequest.advisorParams);
ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context);
var processedUserText = StringUtils.hasText(formatParam)
? advisedRequest.userText + System.lineSeparator() + "{spring_ai_soc_format}"
: advisedRequest.userText;
Map<String, Object> userParams = new HashMap<>(advisedRequest.userParams);
if (StringUtils.hasText(formatParam)) {
userParams.put("spring_ai_soc_format", formatParam);
}
var messages = new ArrayList<Message>(advisedRequest.messages);
var textsAreValid = (StringUtils.hasText(processedUserText)
|| StringUtils.hasText(advisedRequest.systemText));
if (textsAreValid) {
if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) {
var systemMessage = new SystemMessage(
new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render());
messages.add(systemMessage);
}
UserMessage userMessage = null;
if (!CollectionUtils.isEmpty(userParams)) {
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
advisedRequest.media);
}
else {
userMessage = new UserMessage(processedUserText, advisedRequest.media);
}
messages.add(userMessage);
}
if (advisedRequest.chatOptions instanceof FunctionCallingOptions functionCallingOptions) {
if (!advisedRequest.functionNames.isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames));
}
if (!advisedRequest.functionCallbacks.isEmpty()) {
functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks);
}
}
var prompt = new Prompt(messages, advisedRequest.chatOptions);
var chatResponse = this.chatModel.call(prompt);
ChatResponse advisedResponse = chatResponse;
// apply the advisors on response
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
for (RequestResponseAdvisor advisor : currentAdvisors) {
advisedResponse = advisor.adviseResponse(advisedResponse, context);
}
}
return advisedResponse;
}
public ChatResponse chatResponse() {
return doGetChatResponse();
}
public String content() {
return doGetChatResponse().getResult().getOutput().getContent();
}
}
public static class StreamResponseSpec {
private final ChatClientRequest request;
private final StreamingChatModel chatModel;
public StreamResponseSpec(StreamingChatModel streamingChatModel, ChatClientRequest request) {
this.chatModel = streamingChatModel;
this.request = request;
}
private Flux<ChatResponse> doGetFluxChatResponse(ChatClientRequest inputRequest) {
Map<String, Object> context = new ConcurrentHashMap<>();
context.putAll(inputRequest.advisorParams);
ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context);
String processedUserText = advisedRequest.userText;
Map<String, Object> userParams = new HashMap<>(advisedRequest.userParams);
var messages = new ArrayList<Message>(advisedRequest.messages);
var textsAreValid = (StringUtils.hasText(processedUserText)
|| StringUtils.hasText(advisedRequest.systemText));
if (textsAreValid) {
UserMessage userMessage = null;
if (!CollectionUtils.isEmpty(userParams)) {
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
advisedRequest.media);
}
else {
userMessage = new UserMessage(processedUserText, advisedRequest.media);
}
if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) {
var systemMessage = new SystemMessage(
new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render());
messages.add(systemMessage);
}
messages.add(userMessage);
}
if (advisedRequest.chatOptions instanceof
FunctionCallingOptions functionCallingOptions) {
if (!advisedRequest.functionNames.isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames));
}
if (!advisedRequest.functionCallbacks.isEmpty()) {
functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks);
}
}
var prompt = new Prompt(messages, advisedRequest.chatOptions);
var fluxChatResponse = this.chatModel.stream(prompt);
Flux<ChatResponse> advisedResponse = fluxChatResponse;
// apply the advisors on response
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
for (RequestResponseAdvisor advisor : currentAdvisors) {
advisedResponse = advisor.adviseResponse(advisedResponse, context);
}
}
return advisedResponse;
}
public Flux<ChatResponse> chatResponse() {
return doGetFluxChatResponse(this.request);
}
public Flux<String> content() {
return doGetFluxChatResponse(this.request).map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getContent() == null) {
return "";
}
return r.getResult().getOutput().getContent();
}).filter(v -> StringUtils.hasText(v));
}
}
public CallResponseSpec call() {
return new CallResponseSpec(this.chatModel, this);
}
public StreamResponseSpec stream() {
return new StreamResponseSpec((StreamingChatModel) this.chatModel, this);
}
StreamResponseSpec stream();
}
class Builder {
/**
* A mutable builder for creating a {@link ChatClient}.
*/
interface Builder {
private final ChatClientRequest defaultRequest;
Builder defaultAdvisors(RequestResponseAdvisor... advisor);
private final ChatModel chatModel;
Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);
Builder(ChatModel chatModel) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of());
}
Builder defaultAdvisors(List<RequestResponseAdvisor> advisors);
public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
this.defaultRequest.advisors(advisor);
return this;
}
ChatClient build();
public Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer) {
this.defaultRequest.advisors(advisorSpecConsumer);
return this;
}
Builder defaultOptions(ChatOptions chatOptions);
public Builder defaultAdvisors(List<RequestResponseAdvisor> advisors) {
this.defaultRequest.advisors(advisors);
return this;
}
Builder defaultUser(String text);
public ChatClient build() {
return new DefaultChatClient(this.chatModel, this.defaultRequest);
}
Builder defaultUser(Resource text, Charset charset);
public Builder defaultOptions(ChatOptions chatOptions) {
this.defaultRequest.options(chatOptions);
return this;
}
Builder defaultUser(Resource text);
public Builder defaultUser(String text) {
this.defaultRequest.user(text);
return this;
}
Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);
public Builder defaultUser(Resource text, Charset charset) {
try {
this.defaultRequest.user(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
Builder defaultSystem(String text);
public Builder defaultUser(Resource text) {
return this.defaultUser(text, Charset.defaultCharset());
}
Builder defaultSystem(Resource text, Charset charset);
public Builder defaultUser(Consumer<UserSpec> userSpecConsumer) {
this.defaultRequest.user(userSpecConsumer);
return this;
}
Builder defaultSystem(Resource text);
public Builder defaultSystem(String text) {
this.defaultRequest.system(text);
return this;
}
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
public Builder defaultSystem(Resource text, Charset charset) {
try {
this.defaultRequest.system(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
<I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function);
public Builder defaultSystem(Resource text) {
return this.defaultSystem(text, Charset.defaultCharset());
}
public Builder defaultSystem(Consumer<SystemSpec> systemSpecConsumer) {
this.defaultRequest.system(systemSpecConsumer);
return this;
}
public <I, O> Builder defaultFunction(String name, String description,
java.util.function.Function<I, O> function) {
this.defaultRequest.function(name, description, function);
return this;
}
public Builder defaultFunctions(String... functionNames) {
this.defaultRequest.functions(functionNames);
return this;
}
Builder defaultFunctions(String... functionNames);
}

View File

@@ -1,42 +1,74 @@
package org.springframework.ai.chat.client;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.StringUtils;
/**
* The default implementation of {@link ChatClient} as created by the
* {@link ChatClient.Builder#build()} } method.
* {@link Builder#build()} } method.
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @author Arjen Poutsma
* @since 1.0.0 M1
* @since 1.0.0
*/
class DefaultChatClient implements ChatClient {
public class DefaultChatClient implements ChatClient {
private final ChatModel chatModel;
private final ChatClientRequest defaultChatClientRequest;
private final DefaultChatClientRequestSpec defaultChatClientRequest;
public DefaultChatClient(ChatModel chatModel, ChatClientRequest defaultChatClientRequest) {
public DefaultChatClient(ChatModel chatModel, DefaultChatClientRequestSpec defaultChatClientRequest) {
this.chatModel = chatModel;
this.defaultChatClientRequest = defaultChatClientRequest;
}
@Override
public ChatClientRequest prompt() {
return new ChatClientRequest(this.defaultChatClientRequest);
public ChatClientRequestSpec prompt() {
return new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
}
@Override
public ChatClientPromptRequest prompt(Prompt prompt) {
return new ChatClientPromptRequest(this.chatModel, prompt);
public ChatClientPromptRequestSpec prompt(Prompt prompt) {
return new DefaultChatClientPromptRequestSpec(this.chatModel, prompt);
}
/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
@Override
@@ -44,6 +76,732 @@ class DefaultChatClient implements ChatClient {
return this.defaultChatClientRequest.mutate();
}
public static class DefaultPromptUserSpec implements PromptUserSpec {
private String text = "";
private final Map<String, Object> params = new HashMap<>();
private final List<Media> media = new ArrayList<>();
@Override
public PromptUserSpec media(Media... media) {
this.media.addAll(Arrays.asList(media));
return this;
}
@Override
public PromptUserSpec media(MimeType mimeType, URL url) {
this.media.add(new Media(mimeType, url));
return this;
}
@Override
public PromptUserSpec media(MimeType mimeType, Resource resource) {
this.media.add(new Media(mimeType, resource));
return this;
}
@Override
public PromptUserSpec text(String text) {
this.text = text;
return this;
}
@Override
public PromptUserSpec text(Resource text, Charset charset) {
try {
this.text(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
@Override
public PromptUserSpec text(Resource text) {
this.text(text, Charset.defaultCharset());
return this;
}
@Override
public PromptUserSpec param(String k, Object v) {
this.params.put(k, v);
return this;
}
@Override
public PromptUserSpec params(Map<String, Object> p) {
this.params.putAll(p);
return this;
}
protected String text() {
return this.text;
}
protected Map<String, Object> params() {
return this.params;
}
@Override
public List<Media> media() {
return this.media;
}
}
public static class DefaultPromptSystemSpec implements PromptSystemSpec {
private String text = "";
private final Map<String, Object> params = new HashMap<>();
@Override
public PromptSystemSpec text(String text) {
this.text = text;
return this;
}
@Override
public PromptSystemSpec text(Resource text, Charset charset) {
try {
this.text(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
@Override
public PromptSystemSpec text(Resource text) {
this.text(text, Charset.defaultCharset());
return this;
}
@Override
public PromptSystemSpec param(String k, Object v) {
this.params.put(k, v);
return this;
}
@Override
public PromptSystemSpec params(Map<String, Object> p) {
this.params.putAll(p);
return this;
}
protected String text() {
return this.text;
}
protected Map<String, Object> params() {
return this.params;
}
}
public static class DefaultAdvisorSpec implements AdvisorSpec {
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
private final Map<String, Object> params = new HashMap<>();
public AdvisorSpec param(String k, Object v) {
this.params.put(k, v);
return this;
}
public AdvisorSpec params(Map<String, Object> p) {
this.params.putAll(p);
return this;
}
public AdvisorSpec advisors(RequestResponseAdvisor... advisors) {
this.advisors.addAll(List.of(advisors));
return this;
}
public AdvisorSpec advisors(List<RequestResponseAdvisor> advisors) {
this.advisors.addAll(advisors);
return this;
}
public List<RequestResponseAdvisor> getAdvisors() {
return advisors;
}
public Map<String, Object> getParams() {
return params;
}
}
public static class DefaultCallResponseSpec implements CallResponseSpec {
private final DefaultChatClientRequestSpec request;
private final ChatModel chatModel;
public DefaultCallResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) {
this.chatModel = chatModel;
this.request = request;
}
public <T> T entity(ParameterizedTypeReference<T> type) {
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
}
public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
return doSingleWithBeanOutputConverter(structuredOutputConverter);
}
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var stringResponse = chatResponse.getResult().getOutput().getContent();
return boc.convert(stringResponse);
}
public <T> T entity(Class<T> type) {
Assert.notNull(type, "the class must be non-null");
var boc = new BeanOutputConverter<T>(type);
return doSingleWithBeanOutputConverter(boc);
}
private ChatResponse doGetChatResponse() {
return this.doGetChatResponse(this.request, "");
}
private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) {
Map<String, Object> context = new ConcurrentHashMap<>();
context.putAll(inputRequest.getAdvisorParams());
DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest,
context);
var processedUserText = StringUtils.hasText(formatParam)
? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}"
: advisedRequest.getUserText();
Map<String, Object> userParams = new HashMap<>(advisedRequest.getUserParams());
if (StringUtils.hasText(formatParam)) {
userParams.put("spring_ai_soc_format", formatParam);
}
var messages = new ArrayList<Message>(advisedRequest.getMessages());
var textsAreValid = (StringUtils.hasText(processedUserText)
|| StringUtils.hasText(advisedRequest.getSystemText()));
if (textsAreValid) {
if (StringUtils.hasText(advisedRequest.getSystemText())
|| !advisedRequest.getSystemParams().isEmpty()) {
var systemMessage = new SystemMessage(
new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams())
.render());
messages.add(systemMessage);
}
UserMessage userMessage = null;
if (!CollectionUtils.isEmpty(userParams)) {
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
advisedRequest.getMedia());
}
else {
userMessage = new UserMessage(processedUserText, advisedRequest.getMedia());
}
messages.add(userMessage);
}
if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
if (!advisedRequest.getFunctionNames().isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames()));
}
if (!advisedRequest.getFunctionCallbacks().isEmpty()) {
functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks());
}
}
var prompt = new Prompt(messages, advisedRequest.getChatOptions());
var chatResponse = this.chatModel.call(prompt);
ChatResponse advisedResponse = chatResponse;
// apply the advisors on response
if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) {
var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors());
for (RequestResponseAdvisor advisor : currentAdvisors) {
advisedResponse = advisor.adviseResponse(advisedResponse, context);
}
}
return advisedResponse;
}
public ChatResponse chatResponse() {
return doGetChatResponse();
}
public String content() {
return doGetChatResponse().getResult().getOutput().getContent();
}
}
public static class DefaultStreamResponseSpec implements StreamResponseSpec {
private final DefaultChatClientRequestSpec request;
private final ChatModel chatModel;
public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) {
this.chatModel = chatModel;
this.request = request;
}
private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
Map<String, Object> context = new ConcurrentHashMap<>();
context.putAll(inputRequest.getAdvisorParams());
DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest,
context);
String processedUserText = advisedRequest.getUserText();
Map<String, Object> userParams = new HashMap<>(advisedRequest.getUserParams());
var messages = new ArrayList<Message>(advisedRequest.getMessages());
var textsAreValid = (StringUtils.hasText(processedUserText)
|| StringUtils.hasText(advisedRequest.getSystemText()));
if (textsAreValid) {
UserMessage userMessage = null;
if (!CollectionUtils.isEmpty(userParams)) {
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
advisedRequest.getMedia());
}
else {
userMessage = new UserMessage(processedUserText, advisedRequest.getMedia());
}
if (StringUtils.hasText(advisedRequest.getSystemText())
|| !advisedRequest.getSystemParams().isEmpty()) {
var systemMessage = new SystemMessage(
new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams())
.render());
messages.add(systemMessage);
}
messages.add(userMessage);
}
if (advisedRequest.getChatOptions() instanceof
FunctionCallingOptions functionCallingOptions) {
if (!advisedRequest.getFunctionNames().isEmpty()) {
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames()));
}
if (!advisedRequest.getFunctionCallbacks().isEmpty()) {
functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks());
}
}
var prompt = new Prompt(messages, advisedRequest.getChatOptions());
var fluxChatResponse = this.chatModel.stream(prompt);
Flux<ChatResponse> advisedResponse = fluxChatResponse;
// apply the advisors on response
if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) {
var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors());
for (RequestResponseAdvisor advisor : currentAdvisors) {
advisedResponse = advisor.adviseResponse(advisedResponse, context);
}
}
return advisedResponse;
}
public Flux<ChatResponse> chatResponse() {
return doGetFluxChatResponse(this.request);
}
public Flux<String> content() {
return doGetFluxChatResponse(this.request).map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getContent() == null) {
return "";
}
return r.getResult().getOutput().getContent();
}).filter(v -> StringUtils.hasText(v));
}
}
public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {
private final ChatModel chatModel;
private String userText = "";
private String systemText = "";
private ChatOptions chatOptions;
private final List<Media> media = new ArrayList<>();
private final List<String> functionNames = new ArrayList<>();
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
private final List<Message> messages = new ArrayList<>();
private final Map<String, Object> userParams = new HashMap<>();
private final Map<String, Object> systemParams = new HashMap<>();
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
private final Map<String, Object> advisorParams = new HashMap<>();
public String getUserText() {
return userText;
}
public Map<String, Object> getUserParams() {
return userParams;
}
public String getSystemText() {
return systemText;
}
public Map<String, Object> getSystemParams() {
return systemParams;
}
public ChatOptions getChatOptions() {
return chatOptions;
}
public List<RequestResponseAdvisor> getAdvisors() {
return advisors;
}
public Map<String, Object> getAdvisorParams() {
return advisorParams;
}
public List<Message> getMessages() {
return messages;
}
public List<Media> getMedia() {
return media;
}
public List<String> getFunctionNames() {
return this.functionNames;
}
public List<FunctionCallback> getFunctionCallbacks() {
return functionCallbacks;
}
/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
}
public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
this.userText = userText;
this.userParams.putAll(userParams);
this.systemText = systemText;
this.systemParams.putAll(systemParams);
this.functionNames.addAll(functionNames);
this.functionCallbacks.addAll(functionCallbacks);
this.messages.addAll(messages);
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
}
/**
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
public Builder mutate() {
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient.builder(chatModel)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
.media(this.media.toArray(new Media[this.media.size()])))
.defaultOptions(this.chatOptions)
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
// workaround to set the missing fields.
builder.defaultRequest.getMessages().addAll(this.messages);
builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks);
return builder;
}
public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer) {
Assert.notNull(consumer, "the consumer must be non-null");
var as = new DefaultAdvisorSpec();
consumer.accept(as);
this.advisorParams.putAll(as.getParams());
this.advisors.addAll(as.getAdvisors());
return this;
}
public ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors) {
Assert.notNull(advisors, "the advisors must be non-null");
this.advisors.addAll(List.of(advisors));
return this;
}
public ChatClientRequestSpec advisors(List<RequestResponseAdvisor> advisors) {
Assert.notNull(advisors, "the advisors must be non-null");
this.advisors.addAll(advisors);
return this;
}
public ChatClientRequestSpec messages(Message... messages) {
Assert.notNull(messages, "the messages must be non-null");
this.messages.addAll(List.of(messages));
return this;
}
public ChatClientRequestSpec messages(List<Message> messages) {
Assert.notNull(messages, "the messages must be non-null");
this.messages.addAll(messages);
return this;
}
public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
Assert.notNull(options, "the options must be non-null");
this.chatOptions = options;
return this;
}
public <I, O> ChatClientRequestSpec function(String name, String description,
java.util.function.Function<I, O> function) {
Assert.hasText(name, "the name must be non-null and non-empty");
Assert.hasText(description, "the description must be non-null and non-empty");
Assert.notNull(function, "the function must be non-null");
var fcw = FunctionCallbackWrapper.builder(function)
.withDescription(description)
.withName(name)
.withResponseConverter(Object::toString)
.build();
this.functionCallbacks.add(fcw);
return this;
}
public ChatClientRequestSpec functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
this.functionNames.addAll(List.of(functionBeanNames));
return this;
}
public ChatClientRequestSpec system(String text) {
Assert.notNull(text, "the text must be non-null");
this.systemText = text;
return this;
}
public ChatClientRequestSpec system(Resource textResource, Charset charset) {
Assert.notNull(textResource, "the text resource must be non-null");
Assert.notNull(charset, "the charset must be non-null");
try {
this.systemText = textResource.getContentAsString(charset);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
public ChatClientRequestSpec system(Resource text) {
Assert.notNull(text, "the text resource must be non-null");
return this.system(text, Charset.defaultCharset());
}
public ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer) {
Assert.notNull(consumer, "the consumer must be non-null");
var ss = new DefaultPromptSystemSpec();
consumer.accept(ss);
this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText;
this.systemParams.putAll(ss.params());
return this;
}
public ChatClientRequestSpec user(String text) {
Assert.notNull(text, "the text must be non-null");
this.userText = text;
return this;
}
public ChatClientRequestSpec user(Resource text, Charset charset) {
Assert.notNull(text, "the text resource must be non-null");
Assert.notNull(charset, "the charset must be non-null");
try {
this.userText = text.getContentAsString(charset);
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
public ChatClientRequestSpec user(Resource text) {
Assert.notNull(text, "the text resource must be non-null");
return this.user(text, Charset.defaultCharset());
}
public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
Assert.notNull(consumer, "the consumer must be non-null");
var us = new DefaultPromptUserSpec();
consumer.accept(us);
this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
this.userParams.putAll(us.params());
this.media.addAll(us.media());
return this;
}
public CallResponseSpec call() {
return new DefaultCallResponseSpec(chatModel, this);
}
public StreamResponseSpec stream() {
return new DefaultStreamResponseSpec(chatModel, this);
}
public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest,
Map<String, Object> context) {
DefaultChatClientRequestSpec advisedRequest = inputRequest;
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText,
inputRequest.systemText, inputRequest.chatOptions, inputRequest.media,
inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages,
inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors,
inputRequest.advisorParams);
// apply the advisors onRequest
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
for (RequestResponseAdvisor advisor : currentAdvisors) {
adviseRequest = advisor.adviseRequest(adviseRequest, context);
}
advisedRequest = new DefaultChatClientRequestSpec(adviseRequest.chatModel(), adviseRequest.userText(),
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
}
return advisedRequest;
}
}
// Prompt
public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec {
private final ChatModel chatModel;
private final Prompt prompt;
public DefaultCallPromptResponseSpec(ChatModel chatModel, Prompt prompt) {
this.chatModel = chatModel;
this.prompt = prompt;
}
public String content() {
return doGetChatResponse(this.prompt).getResult().getOutput().getContent();
}
public List<String> contents() {
return doGetChatResponse(this.prompt).getResults().stream().map(r -> r.getOutput().getContent()).toList();
}
public ChatResponse chatResponse() {
return doGetChatResponse(this.prompt);
}
private ChatResponse doGetChatResponse(Prompt prompt) {
return chatModel.call(prompt);
}
}
public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec {
private final Prompt prompt;
private final StreamingChatModel chatModel;
public DefaultStreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) {
this.chatModel = streamingChatModel;
this.prompt = prompt;
}
public Flux<ChatResponse> chatResponse() {
return doGetFluxChatResponse(this.prompt);
}
private Flux<ChatResponse> doGetFluxChatResponse(Prompt prompt) {
return this.chatModel.stream(prompt);
}
public Flux<String> content() {
return doGetFluxChatResponse(this.prompt).map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getContent() == null) {
return "";
}
return r.getResult().getOutput().getContent();
}).filter(v -> StringUtils.hasText(v));
}
}
public static class DefaultChatClientPromptRequestSpec implements ChatClientPromptRequestSpec {
private final ChatModel chatModel;
private final Prompt prompt;
public DefaultChatClientPromptRequestSpec(ChatModel chatModel, Prompt prompt) {
this.chatModel = chatModel;
this.prompt = prompt;
}
public CallPromptResponseSpec call() {
return new DefaultCallPromptResponseSpec(this.chatModel, this.prompt);
}
public StreamPromptResponseSpec stream() {
return new DefaultStreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt);
}
}
/**
* use the new fluid DSL starting in {@link #prompt()}
* @param prompt the {@link Prompt prompt} object
@@ -55,4 +813,4 @@ class DefaultChatClient implements ChatClient {
return this.chatModel.call(prompt);
}
}
}

View File

@@ -0,0 +1,137 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.ai.chat.client.ChatClient.Builder;
import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec;
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
/**
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @author Arjen Poutsma
* @since 1.0.0
*
*/
public class DefaultChatClientBuilder implements Builder {
protected final DefaultChatClientRequestSpec defaultRequest;
private final ChatModel chatModel;
public DefaultChatClientBuilder(ChatModel chatModel) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(),
List.of(), List.of(), List.of(), null, List.of(), Map.of());
}
public ChatClient build() {
return new DefaultChatClient(this.chatModel, this.defaultRequest);
}
public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
this.defaultRequest.advisors(advisor);
return this;
}
public Builder defaultAdvisors(Consumer<ChatClient.AdvisorSpec> advisorSpecConsumer) {
this.defaultRequest.advisors(advisorSpecConsumer);
return this;
}
public Builder defaultAdvisors(List<RequestResponseAdvisor> advisors) {
this.defaultRequest.advisors(advisors);
return this;
}
public Builder defaultOptions(ChatOptions chatOptions) {
this.defaultRequest.options(chatOptions);
return this;
}
public Builder defaultUser(String text) {
this.defaultRequest.user(text);
return this;
}
public Builder defaultUser(Resource text, Charset charset) {
try {
this.defaultRequest.user(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
public Builder defaultUser(Resource text) {
return this.defaultUser(text, Charset.defaultCharset());
}
public Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer) {
this.defaultRequest.user(userSpecConsumer);
return this;
}
public Builder defaultSystem(String text) {
this.defaultRequest.system(text);
return this;
}
public Builder defaultSystem(Resource text, Charset charset) {
try {
this.defaultRequest.system(text.getContentAsString(charset));
}
catch (IOException e) {
throw new RuntimeException(e);
}
return this;
}
public Builder defaultSystem(Resource text) {
return this.defaultSystem(text, Charset.defaultCharset());
}
public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
this.defaultRequest.system(systemSpecConsumer);
return this;
}
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
this.defaultRequest.function(name, description, function);
return this;
}
public Builder defaultFunctions(String... functionNames) {
this.defaultRequest.functions(functionNames);
return this;
}
}

View File

@@ -20,7 +20,6 @@ import java.util.Map;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient.ChatClientRequest;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
@@ -31,7 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
* chain of advisors with chared execution context.
*
* @author Christian Tzolov
* @since 1.0.0 M1
* @since 1.0.0
*/
public interface RequestResponseAdvisor {