diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index 9e457ec5c..add0d0416 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -15,41 +15,26 @@ */ package org.springframework.ai.chat.client; -import java.io.IOException; import java.net.URL; import java.nio.charset.Charset; -import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.Media; import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; -import org.springframework.ai.chat.model.StreamingChatModel; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; -import org.springframework.util.StringUtils; /** * Client to perform stateless requests to an AI Model, using a fluent API. @@ -60,7 +45,7 @@ import org.springframework.util.StringUtils; * @author Christian Tzolov * @author Josh Long * @author Arjen Poutsma - * @since 1.0.0 M1 + * @since 1.0.0 */ public interface ChatClient { @@ -69,761 +54,200 @@ public interface ChatClient { } static Builder builder(ChatModel chatModel) { - return new Builder(chatModel); + return new DefaultChatClientBuilder(chatModel); } - ChatClientRequest prompt(); + ChatClientRequestSpec prompt(); - ChatClientPromptRequest prompt(Prompt prompt); + ChatClientPromptRequestSpec prompt(Prompt prompt); /** * Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose - * settings are replicated from the default {@link ChatClientRequest} of this client. + * settings are replicated from the default {@link ChatClientRequestSpec} of this + * client. */ Builder mutate(); - interface PromptSpec { + interface PromptUserSpec { - T text(String text); + PromptUserSpec text(String text); - T text(Resource text, Charset charset); + PromptUserSpec text(Resource text, Charset charset); - T text(Resource text); + PromptUserSpec text(Resource text); - T params(Map p); + PromptUserSpec params(Map p); - T param(String k, Object v); + PromptUserSpec param(String k, Object v); + + PromptUserSpec media(Media... media); + + PromptUserSpec media(MimeType mimeType, URL url); + + PromptUserSpec media(MimeType mimeType, Resource resource); + + List media(); } - abstract class AbstractPromptSpec> implements PromptSpec { + interface PromptSystemSpec { - private String text = ""; + PromptSystemSpec text(String text); - private final Map params = new HashMap<>(); + PromptSystemSpec text(Resource text, Charset charset); - @Override - public T text(String text) { - this.text = text; - return self(); - } + PromptSystemSpec text(Resource text); - @Override - public T text(Resource text, Charset charset) { - try { - this.text(text.getContentAsString(charset)); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return self(); - } + PromptSystemSpec params(Map p); - @Override - public T text(Resource text) { - this.text(text, Charset.defaultCharset()); - return self(); - } - - @Override - public T param(String k, Object v) { - this.params.put(k, v); - return self(); - } - - @Override - public T params(Map p) { - this.params.putAll(p); - return self(); - } - - protected abstract T self(); - - protected String text() { - return this.text; - } - - protected Map params() { - return this.params; - } + PromptSystemSpec param(String k, Object v); } - class UserSpec extends AbstractPromptSpec implements PromptSpec { + interface AdvisorSpec { - private final List media = new ArrayList<>(); + AdvisorSpec param(String k, Object v); - public UserSpec media(Media... media) { - this.media.addAll(Arrays.asList(media)); - return self(); - } + AdvisorSpec params(Map p); - public UserSpec media(MimeType mimeType, URL url) { - this.media.add(new Media(mimeType, url)); - return self(); - } + AdvisorSpec advisors(RequestResponseAdvisor... advisors); - public UserSpec media(MimeType mimeType, Resource resource) { - this.media.add(new Media(mimeType, resource)); - return self(); - } - - protected List media() { - return this.media; - } - - @Override - protected UserSpec self() { - return this; - } + AdvisorSpec advisors(List advisors); } - class SystemSpec extends AbstractPromptSpec implements PromptSpec { + interface CallResponseSpec { - @Override - protected SystemSpec self() { - return this; - } + T entity(ParameterizedTypeReference type); + + T entity(StructuredOutputConverter structuredOutputConverter); + + T entity(Class type); + + ChatResponse chatResponse(); + + String content(); } - class ChatClientPromptRequest { + interface StreamResponseSpec { - private final ChatModel chatModel; + Flux chatResponse(); - private final Prompt prompt; - - public ChatClientPromptRequest(ChatModel chatModel, Prompt prompt) { - this.chatModel = chatModel; - this.prompt = prompt; - } - - public ChatClientRequest.CallPromptResponseSpec call() { - return new ChatClientRequest.CallPromptResponseSpec(this.chatModel, this.prompt); - } - - public ChatClientRequest.StreamPromptResponseSpec stream() { - return new ChatClientRequest.StreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt); - } + Flux content(); } - class AdvisorSpec { + interface ChatClientPromptRequestSpec { - private List advisors = new ArrayList<>(); + CallPromptResponseSpec call(); - private final Map params = new HashMap<>(); - - public AdvisorSpec param(String k, Object v) { - this.params.put(k, v); - return this; - } - - public AdvisorSpec params(Map p) { - this.params.putAll(p); - return this; - } - - public AdvisorSpec advisors(RequestResponseAdvisor... advisors) { - this.advisors.addAll(List.of(advisors)); - return this; - } - - public AdvisorSpec advisors(List advisors) { - this.advisors.addAll(advisors); - return this; - } + StreamPromptResponseSpec stream(); } - class ChatClientRequest { + interface CallPromptResponseSpec { - private final ChatModel chatModel; + String content(); - private String userText = ""; + List contents(); - private String systemText = ""; + ChatResponse chatResponse(); - private ChatOptions chatOptions; + } - private final List media = new ArrayList<>(); + interface StreamPromptResponseSpec { - private final List functionNames = new ArrayList<>(); + Flux chatResponse(); - private final List functionCallbacks = new ArrayList<>(); + public Flux content(); - private final List messages = new ArrayList<>(); + } - private final Map userParams = new HashMap<>(); - - private final Map systemParams = new HashMap<>(); - - private List advisors = new ArrayList<>(); - - private final Map advisorParams = new HashMap<>(); - - /* copy constructor */ - ChatClientRequest(ChatClientRequest ccr) { - this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, - ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams); - } - - public ChatClientRequest(ChatModel chatModel, String userText, Map userParams, - String systemText, Map systemParams, List functionCallbacks, - List messages, List functionNames, List media, ChatOptions chatOptions, - List advisors, Map advisorParams) { - - this.chatModel = chatModel; - this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions(); - - this.userText = userText; - this.userParams.putAll(userParams); - this.systemText = systemText; - this.systemParams.putAll(systemParams); - - this.functionNames.addAll(functionNames); - this.functionCallbacks.addAll(functionCallbacks); - this.messages.addAll(messages); - this.media.addAll(media); - this.advisors.addAll(advisors); - this.advisorParams.putAll(advisorParams); - } + interface ChatClientRequestSpec { /** * Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose * settings are replicated from this {@code ChatClientRequest}. */ - public Builder mutate() { - Builder builder = ChatClient.builder(chatModel) - .defaultSystem(s -> s.text(this.systemText).params(this.systemParams)) - .defaultUser(u -> u.text(this.userText) - .params(this.userParams) - .media(this.media.toArray(new Media[this.media.size()]))) - .defaultOptions(this.chatOptions) - .defaultFunctions(StringUtils.toStringArray(this.functionNames)); + Builder mutate(); - // workaround to set the missing fields. - builder.defaultRequest.messages.addAll(this.messages); - builder.defaultRequest.functionCallbacks.addAll(this.functionCallbacks); + ChatClientRequestSpec advisors(Consumer consumer); - return builder; - } + ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors); - public ChatClientRequest advisors(Consumer consumer) { - Assert.notNull(consumer, "the consumer must be non-null"); - var as = new AdvisorSpec(); - consumer.accept(as); - this.advisorParams.putAll(as.params); - this.advisors.addAll(as.advisors); - return this; - } + ChatClientRequestSpec advisors(List advisors); - public ChatClientRequest advisors(RequestResponseAdvisor... advisors) { - Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(List.of(advisors)); - return this; - } + ChatClientRequestSpec messages(Message... messages); - public ChatClientRequest advisors(List advisors) { - Assert.notNull(advisors, "the advisors must be non-null"); - this.advisors.addAll(advisors); - return this; - } + ChatClientRequestSpec messages(List messages); - public ChatClientRequest messages(Message... messages) { - Assert.notNull(messages, "the messages must be non-null"); - this.messages.addAll(List.of(messages)); - return this; - } + ChatClientRequestSpec options(T options); - public ChatClientRequest messages(List messages) { - Assert.notNull(messages, "the messages must be non-null"); - this.messages.addAll(messages); - return this; - } + ChatClientRequestSpec function(String name, String description, + java.util.function.Function function); - public ChatClientRequest options(T options) { - Assert.notNull(options, "the options must be non-null"); - this.chatOptions = options; - return this; - } + ChatClientRequestSpec functions(String... functionBeanNames); - public ChatClientRequest function(String name, String description, - java.util.function.Function function) { + ChatClientRequestSpec system(String text); - Assert.hasText(name, "the name must be non-null and non-empty"); - Assert.hasText(description, "the description must be non-null and non-empty"); - Assert.notNull(function, "the function must be non-null"); + ChatClientRequestSpec system(Resource textResource, Charset charset); - var fcw = FunctionCallbackWrapper.builder(function) - .withDescription(description) - .withName(name) - .withResponseConverter(Object::toString) - .build(); - this.functionCallbacks.add(fcw); - return this; - } + ChatClientRequestSpec system(Resource text); - public ChatClientRequest functions(String... functionBeanNames) { - Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null"); - this.functionNames.addAll(List.of(functionBeanNames)); - return this; - } + ChatClientRequestSpec system(Consumer consumer); - public ChatClientRequest system(String text) { - Assert.notNull(text, "the text must be non-null"); - this.systemText = text; - return this; - } + ChatClientRequestSpec user(String text); - public ChatClientRequest system(Resource textResource, Charset charset) { + ChatClientRequestSpec user(Resource text, Charset charset); - Assert.notNull(textResource, "the text resource must be non-null"); - Assert.notNull(charset, "the charset must be non-null"); + ChatClientRequestSpec user(Resource text); - try { - this.systemText = textResource.getContentAsString(charset); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } + ChatClientRequestSpec user(Consumer consumer); - public ChatClientRequest system(Resource text) { - Assert.notNull(text, "the text resource must be non-null"); - return this.system(text, Charset.defaultCharset()); - } + // ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest, + // Map context); - public ChatClientRequest system(Consumer consumer) { + CallResponseSpec call(); - Assert.notNull(consumer, "the consumer must be non-null"); - - var ss = new SystemSpec(); - consumer.accept(ss); - this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText; - this.systemParams.putAll(ss.params()); - - return this; - } - - public ChatClientRequest user(String text) { - Assert.notNull(text, "the text must be non-null"); - this.userText = text; - return this; - } - - public ChatClientRequest user(Resource text, Charset charset) { - - Assert.notNull(text, "the text resource must be non-null"); - Assert.notNull(charset, "the charset must be non-null"); - - try { - this.userText = text.getContentAsString(charset); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } - - public ChatClientRequest user(Resource text) { - Assert.notNull(text, "the text resource must be non-null"); - return this.user(text, Charset.defaultCharset()); - } - - public ChatClientRequest user(Consumer consumer) { - Assert.notNull(consumer, "the consumer must be non-null"); - - var us = new UserSpec(); - consumer.accept(us); - this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText; - this.userParams.putAll(us.params()); - this.media.addAll(us.media()); - return this; - } - - public static class StreamPromptResponseSpec { - - private final Prompt prompt; - - private final StreamingChatModel chatModel; - - public StreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) { - this.chatModel = streamingChatModel; - this.prompt = prompt; - } - - public Flux chatResponse() { - return doGetFluxChatResponse(this.prompt); - } - - private Flux doGetFluxChatResponse(Prompt prompt) { - return this.chatModel.stream(prompt); - } - - public Flux content() { - return doGetFluxChatResponse(this.prompt).map(r -> { - if (r.getResult() == null || r.getResult().getOutput() == null - || r.getResult().getOutput().getContent() == null) { - return ""; - } - return r.getResult().getOutput().getContent(); - }).filter(v -> StringUtils.hasText(v)); - } - - } - - public static class CallPromptResponseSpec { - - private final ChatModel chatModel; - - private final Prompt prompt; - - public CallPromptResponseSpec(ChatModel chatModel, Prompt prompt) { - this.chatModel = chatModel; - this.prompt = prompt; - } - - public String content() { - return doGetChatResponse(this.prompt).getResult().getOutput().getContent(); - } - - public List contents() { - return doGetChatResponse(this.prompt).getResults() - .stream() - .map(r -> r.getOutput().getContent()) - .toList(); - } - - public ChatResponse chatResponse() { - return doGetChatResponse(this.prompt); - } - - private ChatResponse doGetChatResponse(Prompt prompt) { - return chatModel.call(prompt); - } - - } - - private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest, Map context) { - - ChatClientRequest advisedRequest = inputRequest; - - if (!CollectionUtils.isEmpty(inputRequest.advisors)) { - AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, - inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, - inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, - inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, - inputRequest.advisorParams); - - // apply the advisors onRequest - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - adviseRequest = advisor.adviseRequest(adviseRequest, context); - } - - advisedRequest = new ChatClientRequest(adviseRequest.chatModel(), adviseRequest.userText(), - adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(), - adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(), - adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(), - adviseRequest.advisorParams()); - } - - return advisedRequest; - } - - public static class CallResponseSpec { - - private final ChatClientRequest request; - - private final ChatModel chatModel; - - public CallResponseSpec(ChatModel chatModel, ChatClientRequest request) { - this.chatModel = chatModel; - this.request = request; - } - - public T entity(ParameterizedTypeReference type) { - return doSingleWithBeanOutputConverter(new BeanOutputConverter(type)); - } - - public T entity(StructuredOutputConverter structuredOutputConverter) { - return doSingleWithBeanOutputConverter(structuredOutputConverter); - } - - private T doSingleWithBeanOutputConverter(StructuredOutputConverter boc) { - var chatResponse = doGetChatResponse(this.request, boc.getFormat()); - var stringResponse = chatResponse.getResult().getOutput().getContent(); - return boc.convert(stringResponse); - } - - public T entity(Class type) { - Assert.notNull(type, "the class must be non-null"); - var boc = new BeanOutputConverter(type); - return doSingleWithBeanOutputConverter(boc); - } - - private ChatResponse doGetChatResponse() { - return this.doGetChatResponse(this.request, ""); - } - - private ChatResponse doGetChatResponse(ChatClientRequest inputRequest, String formatParam) { - - Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.advisorParams); - ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context); - - var processedUserText = StringUtils.hasText(formatParam) - ? advisedRequest.userText + System.lineSeparator() + "{spring_ai_soc_format}" - : advisedRequest.userText; - - Map userParams = new HashMap<>(advisedRequest.userParams); - if (StringUtils.hasText(formatParam)) { - userParams.put("spring_ai_soc_format", formatParam); - } - - var messages = new ArrayList(advisedRequest.messages); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.systemText)); - if (textsAreValid) { - if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render()); - messages.add(systemMessage); - } - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.media); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.media); - } - messages.add(userMessage); - } - - if (advisedRequest.chatOptions instanceof FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.functionNames.isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames)); - } - if (!advisedRequest.functionCallbacks.isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks); - } - } - var prompt = new Prompt(messages, advisedRequest.chatOptions); - var chatResponse = this.chatModel.call(prompt); - - ChatResponse advisedResponse = chatResponse; - // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.advisors)) { - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); - } - } - - return advisedResponse; - } - - public ChatResponse chatResponse() { - return doGetChatResponse(); - } - - public String content() { - return doGetChatResponse().getResult().getOutput().getContent(); - } - - } - - public static class StreamResponseSpec { - - private final ChatClientRequest request; - - private final StreamingChatModel chatModel; - - public StreamResponseSpec(StreamingChatModel streamingChatModel, ChatClientRequest request) { - this.chatModel = streamingChatModel; - this.request = request; - } - - private Flux doGetFluxChatResponse(ChatClientRequest inputRequest) { - - Map context = new ConcurrentHashMap<>(); - context.putAll(inputRequest.advisorParams); - ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context); - - String processedUserText = advisedRequest.userText; - Map userParams = new HashMap<>(advisedRequest.userParams); - - var messages = new ArrayList(advisedRequest.messages); - var textsAreValid = (StringUtils.hasText(processedUserText) - || StringUtils.hasText(advisedRequest.systemText)); - if (textsAreValid) { - UserMessage userMessage = null; - if (!CollectionUtils.isEmpty(userParams)) { - userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), - advisedRequest.media); - } - else { - userMessage = new UserMessage(processedUserText, advisedRequest.media); - } - if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) { - var systemMessage = new SystemMessage( - new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render()); - messages.add(systemMessage); - } - messages.add(userMessage); - } - - if (advisedRequest.chatOptions instanceof - - FunctionCallingOptions functionCallingOptions) { - if (!advisedRequest.functionNames.isEmpty()) { - functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames)); - } - if (!advisedRequest.functionCallbacks.isEmpty()) { - functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks); - } - } - var prompt = new Prompt(messages, advisedRequest.chatOptions); - - var fluxChatResponse = this.chatModel.stream(prompt); - - Flux advisedResponse = fluxChatResponse; - // apply the advisors on response - if (!CollectionUtils.isEmpty(inputRequest.advisors)) { - var currentAdvisors = new ArrayList<>(inputRequest.advisors); - for (RequestResponseAdvisor advisor : currentAdvisors) { - advisedResponse = advisor.adviseResponse(advisedResponse, context); - } - } - - return advisedResponse; - } - - public Flux chatResponse() { - return doGetFluxChatResponse(this.request); - } - - public Flux content() { - return doGetFluxChatResponse(this.request).map(r -> { - if (r.getResult() == null || r.getResult().getOutput() == null - || r.getResult().getOutput().getContent() == null) { - return ""; - } - return r.getResult().getOutput().getContent(); - }).filter(v -> StringUtils.hasText(v)); - } - - } - - public CallResponseSpec call() { - return new CallResponseSpec(this.chatModel, this); - } - - public StreamResponseSpec stream() { - return new StreamResponseSpec((StreamingChatModel) this.chatModel, this); - } + StreamResponseSpec stream(); } - class Builder { + /** + * A mutable builder for creating a {@link ChatClient}. + */ + interface Builder { - private final ChatClientRequest defaultRequest; + Builder defaultAdvisors(RequestResponseAdvisor... advisor); - private final ChatModel chatModel; + Builder defaultAdvisors(Consumer advisorSpecConsumer); - Builder(ChatModel chatModel) { - Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); - this.chatModel = chatModel; - this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(), - List.of(), List.of(), null, List.of(), Map.of()); - } + Builder defaultAdvisors(List advisors); - public Builder defaultAdvisors(RequestResponseAdvisor... advisor) { - this.defaultRequest.advisors(advisor); - return this; - } + ChatClient build(); - public Builder defaultAdvisors(Consumer advisorSpecConsumer) { - this.defaultRequest.advisors(advisorSpecConsumer); - return this; - } + Builder defaultOptions(ChatOptions chatOptions); - public Builder defaultAdvisors(List advisors) { - this.defaultRequest.advisors(advisors); - return this; - } + Builder defaultUser(String text); - public ChatClient build() { - return new DefaultChatClient(this.chatModel, this.defaultRequest); - } + Builder defaultUser(Resource text, Charset charset); - public Builder defaultOptions(ChatOptions chatOptions) { - this.defaultRequest.options(chatOptions); - return this; - } + Builder defaultUser(Resource text); - public Builder defaultUser(String text) { - this.defaultRequest.user(text); - return this; - } + Builder defaultUser(Consumer userSpecConsumer); - public Builder defaultUser(Resource text, Charset charset) { - try { - this.defaultRequest.user(text.getContentAsString(charset)); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } + Builder defaultSystem(String text); - public Builder defaultUser(Resource text) { - return this.defaultUser(text, Charset.defaultCharset()); - } + Builder defaultSystem(Resource text, Charset charset); - public Builder defaultUser(Consumer userSpecConsumer) { - this.defaultRequest.user(userSpecConsumer); - return this; - } + Builder defaultSystem(Resource text); - public Builder defaultSystem(String text) { - this.defaultRequest.system(text); - return this; - } + Builder defaultSystem(Consumer systemSpecConsumer); - public Builder defaultSystem(Resource text, Charset charset) { - try { - this.defaultRequest.system(text.getContentAsString(charset)); - } - catch (IOException e) { - throw new RuntimeException(e); - } - return this; - } + Builder defaultFunction(String name, String description, java.util.function.Function function); - public Builder defaultSystem(Resource text) { - return this.defaultSystem(text, Charset.defaultCharset()); - } - - public Builder defaultSystem(Consumer systemSpecConsumer) { - this.defaultRequest.system(systemSpecConsumer); - return this; - } - - public Builder defaultFunction(String name, String description, - java.util.function.Function function) { - this.defaultRequest.function(name, description, function); - return this; - } - - public Builder defaultFunctions(String... functionNames) { - this.defaultRequest.functions(functionNames); - return this; - } + Builder defaultFunctions(String... functionNames); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 8894d0852..ffb4eb2cf 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -1,42 +1,74 @@ package org.springframework.ai.chat.client; +import java.io.IOException; +import java.net.URL; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.messages.Media; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.StreamingChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.StructuredOutputConverter; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MimeType; +import org.springframework.util.StringUtils; /** * The default implementation of {@link ChatClient} as created by the - * {@link ChatClient.Builder#build()} } method. + * {@link Builder#build()} } method. * * @author Mark Pollack * @author Christian Tzolov * @author Josh Long * @author Arjen Poutsma - * @since 1.0.0 M1 + * @since 1.0.0 */ -class DefaultChatClient implements ChatClient { +public class DefaultChatClient implements ChatClient { private final ChatModel chatModel; - private final ChatClientRequest defaultChatClientRequest; + private final DefaultChatClientRequestSpec defaultChatClientRequest; - public DefaultChatClient(ChatModel chatModel, ChatClientRequest defaultChatClientRequest) { + public DefaultChatClient(ChatModel chatModel, DefaultChatClientRequestSpec defaultChatClientRequest) { this.chatModel = chatModel; this.defaultChatClientRequest = defaultChatClientRequest; } @Override - public ChatClientRequest prompt() { - return new ChatClientRequest(this.defaultChatClientRequest); + public ChatClientRequestSpec prompt() { + return new DefaultChatClientRequestSpec(this.defaultChatClientRequest); } @Override - public ChatClientPromptRequest prompt(Prompt prompt) { - return new ChatClientPromptRequest(this.chatModel, prompt); + public ChatClientPromptRequestSpec prompt(Prompt prompt) { + return new DefaultChatClientPromptRequestSpec(this.chatModel, prompt); } /** - * Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose + * Return a {@code ChatClient2Builder} to create a new {@code ChatClient} whose * settings are replicated from this {@code ChatClientRequest}. */ @Override @@ -44,6 +76,732 @@ class DefaultChatClient implements ChatClient { return this.defaultChatClientRequest.mutate(); } + public static class DefaultPromptUserSpec implements PromptUserSpec { + + private String text = ""; + + private final Map params = new HashMap<>(); + + private final List media = new ArrayList<>(); + + @Override + public PromptUserSpec media(Media... media) { + this.media.addAll(Arrays.asList(media)); + return this; + } + + @Override + public PromptUserSpec media(MimeType mimeType, URL url) { + this.media.add(new Media(mimeType, url)); + return this; + } + + @Override + public PromptUserSpec media(MimeType mimeType, Resource resource) { + this.media.add(new Media(mimeType, resource)); + return this; + } + + @Override + public PromptUserSpec text(String text) { + this.text = text; + return this; + } + + @Override + public PromptUserSpec text(Resource text, Charset charset) { + try { + this.text(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public PromptUserSpec text(Resource text) { + this.text(text, Charset.defaultCharset()); + return this; + } + + @Override + public PromptUserSpec param(String k, Object v) { + this.params.put(k, v); + return this; + } + + @Override + public PromptUserSpec params(Map p) { + this.params.putAll(p); + return this; + } + + protected String text() { + return this.text; + } + + protected Map params() { + return this.params; + } + + @Override + public List media() { + return this.media; + } + + } + + public static class DefaultPromptSystemSpec implements PromptSystemSpec { + + private String text = ""; + + private final Map params = new HashMap<>(); + + @Override + public PromptSystemSpec text(String text) { + this.text = text; + return this; + } + + @Override + public PromptSystemSpec text(Resource text, Charset charset) { + try { + this.text(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public PromptSystemSpec text(Resource text) { + this.text(text, Charset.defaultCharset()); + return this; + } + + @Override + public PromptSystemSpec param(String k, Object v) { + this.params.put(k, v); + return this; + } + + @Override + public PromptSystemSpec params(Map p) { + this.params.putAll(p); + return this; + } + + protected String text() { + return this.text; + } + + protected Map params() { + return this.params; + } + + } + + public static class DefaultAdvisorSpec implements AdvisorSpec { + + private List advisors = new ArrayList<>(); + + private final Map params = new HashMap<>(); + + public AdvisorSpec param(String k, Object v) { + this.params.put(k, v); + return this; + } + + public AdvisorSpec params(Map p) { + this.params.putAll(p); + return this; + } + + public AdvisorSpec advisors(RequestResponseAdvisor... advisors) { + this.advisors.addAll(List.of(advisors)); + return this; + } + + public AdvisorSpec advisors(List advisors) { + this.advisors.addAll(advisors); + return this; + } + + public List getAdvisors() { + return advisors; + } + + public Map getParams() { + return params; + } + + } + + public static class DefaultCallResponseSpec implements CallResponseSpec { + + private final DefaultChatClientRequestSpec request; + + private final ChatModel chatModel; + + public DefaultCallResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) { + this.chatModel = chatModel; + this.request = request; + } + + public T entity(ParameterizedTypeReference type) { + return doSingleWithBeanOutputConverter(new BeanOutputConverter(type)); + } + + public T entity(StructuredOutputConverter structuredOutputConverter) { + return doSingleWithBeanOutputConverter(structuredOutputConverter); + } + + private T doSingleWithBeanOutputConverter(StructuredOutputConverter boc) { + var chatResponse = doGetChatResponse(this.request, boc.getFormat()); + var stringResponse = chatResponse.getResult().getOutput().getContent(); + return boc.convert(stringResponse); + } + + public T entity(Class type) { + Assert.notNull(type, "the class must be non-null"); + var boc = new BeanOutputConverter(type); + return doSingleWithBeanOutputConverter(boc); + } + + private ChatResponse doGetChatResponse() { + return this.doGetChatResponse(this.request, ""); + } + + private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) { + + Map context = new ConcurrentHashMap<>(); + context.putAll(inputRequest.getAdvisorParams()); + DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, + context); + + var processedUserText = StringUtils.hasText(formatParam) + ? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}" + : advisedRequest.getUserText(); + + Map userParams = new HashMap<>(advisedRequest.getUserParams()); + if (StringUtils.hasText(formatParam)) { + userParams.put("spring_ai_soc_format", formatParam); + } + + var messages = new ArrayList(advisedRequest.getMessages()); + var textsAreValid = (StringUtils.hasText(processedUserText) + || StringUtils.hasText(advisedRequest.getSystemText())); + if (textsAreValid) { + if (StringUtils.hasText(advisedRequest.getSystemText()) + || !advisedRequest.getSystemParams().isEmpty()) { + var systemMessage = new SystemMessage( + new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) + .render()); + messages.add(systemMessage); + } + UserMessage userMessage = null; + if (!CollectionUtils.isEmpty(userParams)) { + userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), + advisedRequest.getMedia()); + } + else { + userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); + } + messages.add(userMessage); + } + + if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.getFunctionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + } + if (!advisedRequest.getFunctionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + } + } + var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + var chatResponse = this.chatModel.call(prompt); + + ChatResponse advisedResponse = chatResponse; + // apply the advisors on response + if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedResponse = advisor.adviseResponse(advisedResponse, context); + } + } + + return advisedResponse; + } + + public ChatResponse chatResponse() { + return doGetChatResponse(); + } + + public String content() { + return doGetChatResponse().getResult().getOutput().getContent(); + } + + } + + public static class DefaultStreamResponseSpec implements StreamResponseSpec { + + private final DefaultChatClientRequestSpec request; + + private final ChatModel chatModel; + + public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) { + this.chatModel = chatModel; + this.request = request; + } + + private Flux doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) { + + Map context = new ConcurrentHashMap<>(); + context.putAll(inputRequest.getAdvisorParams()); + DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest, + context); + + String processedUserText = advisedRequest.getUserText(); + Map userParams = new HashMap<>(advisedRequest.getUserParams()); + + var messages = new ArrayList(advisedRequest.getMessages()); + var textsAreValid = (StringUtils.hasText(processedUserText) + || StringUtils.hasText(advisedRequest.getSystemText())); + if (textsAreValid) { + UserMessage userMessage = null; + if (!CollectionUtils.isEmpty(userParams)) { + userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(), + advisedRequest.getMedia()); + } + else { + userMessage = new UserMessage(processedUserText, advisedRequest.getMedia()); + } + if (StringUtils.hasText(advisedRequest.getSystemText()) + || !advisedRequest.getSystemParams().isEmpty()) { + var systemMessage = new SystemMessage( + new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams()) + .render()); + messages.add(systemMessage); + } + messages.add(userMessage); + } + + if (advisedRequest.getChatOptions() instanceof + + FunctionCallingOptions functionCallingOptions) { + if (!advisedRequest.getFunctionNames().isEmpty()) { + functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames())); + } + if (!advisedRequest.getFunctionCallbacks().isEmpty()) { + functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks()); + } + } + var prompt = new Prompt(messages, advisedRequest.getChatOptions()); + + var fluxChatResponse = this.chatModel.stream(prompt); + + Flux advisedResponse = fluxChatResponse; + // apply the advisors on response + if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) { + var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors()); + for (RequestResponseAdvisor advisor : currentAdvisors) { + advisedResponse = advisor.adviseResponse(advisedResponse, context); + } + } + + return advisedResponse; + } + + public Flux chatResponse() { + return doGetFluxChatResponse(this.request); + } + + public Flux content() { + return doGetFluxChatResponse(this.request).map(r -> { + if (r.getResult() == null || r.getResult().getOutput() == null + || r.getResult().getOutput().getContent() == null) { + return ""; + } + return r.getResult().getOutput().getContent(); + }).filter(v -> StringUtils.hasText(v)); + } + + } + + public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec { + + private final ChatModel chatModel; + + private String userText = ""; + + private String systemText = ""; + + private ChatOptions chatOptions; + + private final List media = new ArrayList<>(); + + private final List functionNames = new ArrayList<>(); + + private final List functionCallbacks = new ArrayList<>(); + + private final List messages = new ArrayList<>(); + + private final Map userParams = new HashMap<>(); + + private final Map systemParams = new HashMap<>(); + + private List advisors = new ArrayList<>(); + + private final Map advisorParams = new HashMap<>(); + + public String getUserText() { + return userText; + } + + public Map getUserParams() { + return userParams; + } + + public String getSystemText() { + return systemText; + } + + public Map getSystemParams() { + return systemParams; + } + + public ChatOptions getChatOptions() { + return chatOptions; + } + + public List getAdvisors() { + return advisors; + } + + public Map getAdvisorParams() { + return advisorParams; + } + + public List getMessages() { + return messages; + } + + public List getMedia() { + return media; + } + + public List getFunctionNames() { + return this.functionNames; + } + + public List getFunctionCallbacks() { + return functionCallbacks; + } + + /* copy constructor */ + DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) { + this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks, + ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams); + } + + public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map userParams, + String systemText, Map systemParams, List functionCallbacks, + List messages, List functionNames, List media, ChatOptions chatOptions, + List advisors, Map advisorParams) { + + this.chatModel = chatModel; + this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions(); + + this.userText = userText; + this.userParams.putAll(userParams); + this.systemText = systemText; + this.systemParams.putAll(systemParams); + + this.functionNames.addAll(functionNames); + this.functionCallbacks.addAll(functionCallbacks); + this.messages.addAll(messages); + this.media.addAll(media); + this.advisors.addAll(advisors); + this.advisorParams.putAll(advisorParams); + } + + /** + * Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose + * settings are replicated from this {@code ChatClientRequest}. + */ + public Builder mutate() { + DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient.builder(chatModel) + .defaultSystem(s -> s.text(this.systemText).params(this.systemParams)) + .defaultUser(u -> u.text(this.userText) + .params(this.userParams) + .media(this.media.toArray(new Media[this.media.size()]))) + .defaultOptions(this.chatOptions) + .defaultFunctions(StringUtils.toStringArray(this.functionNames)); + + // workaround to set the missing fields. + builder.defaultRequest.getMessages().addAll(this.messages); + builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks); + + return builder; + } + + public ChatClientRequestSpec advisors(Consumer consumer) { + Assert.notNull(consumer, "the consumer must be non-null"); + var as = new DefaultAdvisorSpec(); + consumer.accept(as); + this.advisorParams.putAll(as.getParams()); + this.advisors.addAll(as.getAdvisors()); + return this; + } + + public ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors) { + Assert.notNull(advisors, "the advisors must be non-null"); + this.advisors.addAll(List.of(advisors)); + return this; + } + + public ChatClientRequestSpec advisors(List advisors) { + Assert.notNull(advisors, "the advisors must be non-null"); + this.advisors.addAll(advisors); + return this; + } + + public ChatClientRequestSpec messages(Message... messages) { + Assert.notNull(messages, "the messages must be non-null"); + this.messages.addAll(List.of(messages)); + return this; + } + + public ChatClientRequestSpec messages(List messages) { + Assert.notNull(messages, "the messages must be non-null"); + this.messages.addAll(messages); + return this; + } + + public ChatClientRequestSpec options(T options) { + Assert.notNull(options, "the options must be non-null"); + this.chatOptions = options; + return this; + } + + public ChatClientRequestSpec function(String name, String description, + java.util.function.Function function) { + + Assert.hasText(name, "the name must be non-null and non-empty"); + Assert.hasText(description, "the description must be non-null and non-empty"); + Assert.notNull(function, "the function must be non-null"); + + var fcw = FunctionCallbackWrapper.builder(function) + .withDescription(description) + .withName(name) + .withResponseConverter(Object::toString) + .build(); + this.functionCallbacks.add(fcw); + return this; + } + + public ChatClientRequestSpec functions(String... functionBeanNames) { + Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null"); + this.functionNames.addAll(List.of(functionBeanNames)); + return this; + } + + public ChatClientRequestSpec system(String text) { + Assert.notNull(text, "the text must be non-null"); + this.systemText = text; + return this; + } + + public ChatClientRequestSpec system(Resource textResource, Charset charset) { + + Assert.notNull(textResource, "the text resource must be non-null"); + Assert.notNull(charset, "the charset must be non-null"); + + try { + this.systemText = textResource.getContentAsString(charset); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public ChatClientRequestSpec system(Resource text) { + Assert.notNull(text, "the text resource must be non-null"); + return this.system(text, Charset.defaultCharset()); + } + + public ChatClientRequestSpec system(Consumer consumer) { + + Assert.notNull(consumer, "the consumer must be non-null"); + + var ss = new DefaultPromptSystemSpec(); + consumer.accept(ss); + this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText; + this.systemParams.putAll(ss.params()); + + return this; + } + + public ChatClientRequestSpec user(String text) { + Assert.notNull(text, "the text must be non-null"); + this.userText = text; + return this; + } + + public ChatClientRequestSpec user(Resource text, Charset charset) { + + Assert.notNull(text, "the text resource must be non-null"); + Assert.notNull(charset, "the charset must be non-null"); + + try { + this.userText = text.getContentAsString(charset); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public ChatClientRequestSpec user(Resource text) { + Assert.notNull(text, "the text resource must be non-null"); + return this.user(text, Charset.defaultCharset()); + } + + public ChatClientRequestSpec user(Consumer consumer) { + Assert.notNull(consumer, "the consumer must be non-null"); + + var us = new DefaultPromptUserSpec(); + consumer.accept(us); + this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText; + this.userParams.putAll(us.params()); + this.media.addAll(us.media()); + return this; + } + + public CallResponseSpec call() { + return new DefaultCallResponseSpec(chatModel, this); + } + + public StreamResponseSpec stream() { + return new DefaultStreamResponseSpec(chatModel, this); + } + + public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest, + Map context) { + + DefaultChatClientRequestSpec advisedRequest = inputRequest; + + if (!CollectionUtils.isEmpty(inputRequest.advisors)) { + AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText, + inputRequest.systemText, inputRequest.chatOptions, inputRequest.media, + inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages, + inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors, + inputRequest.advisorParams); + + // apply the advisors onRequest + var currentAdvisors = new ArrayList<>(inputRequest.advisors); + for (RequestResponseAdvisor advisor : currentAdvisors) { + adviseRequest = advisor.adviseRequest(adviseRequest, context); + } + + advisedRequest = new DefaultChatClientRequestSpec(adviseRequest.chatModel(), adviseRequest.userText(), + adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(), + adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(), + adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(), + adviseRequest.advisorParams()); + } + + return advisedRequest; + } + + } + + // Prompt + + public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec { + + private final ChatModel chatModel; + + private final Prompt prompt; + + public DefaultCallPromptResponseSpec(ChatModel chatModel, Prompt prompt) { + this.chatModel = chatModel; + this.prompt = prompt; + } + + public String content() { + return doGetChatResponse(this.prompt).getResult().getOutput().getContent(); + } + + public List contents() { + return doGetChatResponse(this.prompt).getResults().stream().map(r -> r.getOutput().getContent()).toList(); + } + + public ChatResponse chatResponse() { + return doGetChatResponse(this.prompt); + } + + private ChatResponse doGetChatResponse(Prompt prompt) { + return chatModel.call(prompt); + } + + } + + public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec { + + private final Prompt prompt; + + private final StreamingChatModel chatModel; + + public DefaultStreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) { + this.chatModel = streamingChatModel; + this.prompt = prompt; + } + + public Flux chatResponse() { + return doGetFluxChatResponse(this.prompt); + } + + private Flux doGetFluxChatResponse(Prompt prompt) { + return this.chatModel.stream(prompt); + } + + public Flux content() { + return doGetFluxChatResponse(this.prompt).map(r -> { + if (r.getResult() == null || r.getResult().getOutput() == null + || r.getResult().getOutput().getContent() == null) { + return ""; + } + return r.getResult().getOutput().getContent(); + }).filter(v -> StringUtils.hasText(v)); + } + + } + + public static class DefaultChatClientPromptRequestSpec implements ChatClientPromptRequestSpec { + + private final ChatModel chatModel; + + private final Prompt prompt; + + public DefaultChatClientPromptRequestSpec(ChatModel chatModel, Prompt prompt) { + this.chatModel = chatModel; + this.prompt = prompt; + } + + public CallPromptResponseSpec call() { + return new DefaultCallPromptResponseSpec(this.chatModel, this.prompt); + } + + public StreamPromptResponseSpec stream() { + return new DefaultStreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt); + } + + } + /** * use the new fluid DSL starting in {@link #prompt()} * @param prompt the {@link Prompt prompt} object @@ -55,4 +813,4 @@ class DefaultChatClient implements ChatClient { return this.chatModel.call(prompt); } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java new file mode 100644 index 000000000..815611ade --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.springframework.ai.chat.client.ChatClient.Builder; +import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec; +import org.springframework.ai.chat.client.ChatClient.PromptUserSpec; +import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.core.io.Resource; +import org.springframework.util.Assert; + +/** + * @author Mark Pollack + * @author Christian Tzolov + * @author Josh Long + * @author Arjen Poutsma + * @since 1.0.0 + * + */ +public class DefaultChatClientBuilder implements Builder { + + protected final DefaultChatClientRequestSpec defaultRequest; + + private final ChatModel chatModel; + + public DefaultChatClientBuilder(ChatModel chatModel) { + Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null"); + this.chatModel = chatModel; + this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(), + List.of(), List.of(), List.of(), null, List.of(), Map.of()); + } + + public ChatClient build() { + return new DefaultChatClient(this.chatModel, this.defaultRequest); + } + + public Builder defaultAdvisors(RequestResponseAdvisor... advisor) { + this.defaultRequest.advisors(advisor); + return this; + } + + public Builder defaultAdvisors(Consumer advisorSpecConsumer) { + this.defaultRequest.advisors(advisorSpecConsumer); + return this; + } + + public Builder defaultAdvisors(List advisors) { + this.defaultRequest.advisors(advisors); + return this; + } + + public Builder defaultOptions(ChatOptions chatOptions) { + this.defaultRequest.options(chatOptions); + return this; + } + + public Builder defaultUser(String text) { + this.defaultRequest.user(text); + return this; + } + + public Builder defaultUser(Resource text, Charset charset) { + try { + this.defaultRequest.user(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public Builder defaultUser(Resource text) { + return this.defaultUser(text, Charset.defaultCharset()); + } + + public Builder defaultUser(Consumer userSpecConsumer) { + this.defaultRequest.user(userSpecConsumer); + return this; + } + + public Builder defaultSystem(String text) { + this.defaultRequest.system(text); + return this; + } + + public Builder defaultSystem(Resource text, Charset charset) { + try { + this.defaultRequest.system(text.getContentAsString(charset)); + } + catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + public Builder defaultSystem(Resource text) { + return this.defaultSystem(text, Charset.defaultCharset()); + } + + public Builder defaultSystem(Consumer systemSpecConsumer) { + this.defaultRequest.system(systemSpecConsumer); + return this; + } + + public Builder defaultFunction(String name, String description, java.util.function.Function function) { + this.defaultRequest.function(name, description, function); + return this; + } + + public Builder defaultFunctions(String... functionNames) { + this.defaultRequest.functions(functionNames); + return this; + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java index 922898f65..699f3b523 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/RequestResponseAdvisor.java @@ -20,7 +20,6 @@ import java.util.Map; import reactor.core.publisher.Flux; -import org.springframework.ai.chat.client.ChatClient.ChatClientRequest; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; @@ -31,7 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt; * chain of advisors with chared execution context. * * @author Christian Tzolov - * @since 1.0.0 M1 + * @since 1.0.0 */ public interface RequestResponseAdvisor {