ChatClient/DefaultChatClient refactoring
Now the ChatClient contains only the Spec interfaces while all implementations are moved to the DefaultChatClient and the DefaultChatModelBuilder classes.
This commit is contained in:
@@ -15,41 +15,26 @@
|
||||
*/
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.StructuredOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackWrapper;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Client to perform stateless requests to an AI Model, using a fluent API.
|
||||
@@ -60,7 +45,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Christian Tzolov
|
||||
* @author Josh Long
|
||||
* @author Arjen Poutsma
|
||||
* @since 1.0.0 M1
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface ChatClient {
|
||||
|
||||
@@ -69,761 +54,200 @@ public interface ChatClient {
|
||||
}
|
||||
|
||||
static Builder builder(ChatModel chatModel) {
|
||||
return new Builder(chatModel);
|
||||
return new DefaultChatClientBuilder(chatModel);
|
||||
}
|
||||
|
||||
ChatClientRequest prompt();
|
||||
ChatClientRequestSpec prompt();
|
||||
|
||||
ChatClientPromptRequest prompt(Prompt prompt);
|
||||
ChatClientPromptRequestSpec prompt(Prompt prompt);
|
||||
|
||||
/**
|
||||
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
|
||||
* settings are replicated from the default {@link ChatClientRequest} of this client.
|
||||
* settings are replicated from the default {@link ChatClientRequestSpec} of this
|
||||
* client.
|
||||
*/
|
||||
Builder mutate();
|
||||
|
||||
interface PromptSpec<T> {
|
||||
interface PromptUserSpec {
|
||||
|
||||
T text(String text);
|
||||
PromptUserSpec text(String text);
|
||||
|
||||
T text(Resource text, Charset charset);
|
||||
PromptUserSpec text(Resource text, Charset charset);
|
||||
|
||||
T text(Resource text);
|
||||
PromptUserSpec text(Resource text);
|
||||
|
||||
T params(Map<String, Object> p);
|
||||
PromptUserSpec params(Map<String, Object> p);
|
||||
|
||||
T param(String k, Object v);
|
||||
PromptUserSpec param(String k, Object v);
|
||||
|
||||
PromptUserSpec media(Media... media);
|
||||
|
||||
PromptUserSpec media(MimeType mimeType, URL url);
|
||||
|
||||
PromptUserSpec media(MimeType mimeType, Resource resource);
|
||||
|
||||
List<Media> media();
|
||||
|
||||
}
|
||||
|
||||
abstract class AbstractPromptSpec<T extends AbstractPromptSpec<T>> implements PromptSpec<T> {
|
||||
interface PromptSystemSpec {
|
||||
|
||||
private String text = "";
|
||||
PromptSystemSpec text(String text);
|
||||
|
||||
private final Map<String, Object> params = new HashMap<>();
|
||||
PromptSystemSpec text(Resource text, Charset charset);
|
||||
|
||||
@Override
|
||||
public T text(String text) {
|
||||
this.text = text;
|
||||
return self();
|
||||
}
|
||||
PromptSystemSpec text(Resource text);
|
||||
|
||||
@Override
|
||||
public T text(Resource text, Charset charset) {
|
||||
try {
|
||||
this.text(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return self();
|
||||
}
|
||||
PromptSystemSpec params(Map<String, Object> p);
|
||||
|
||||
@Override
|
||||
public T text(Resource text) {
|
||||
this.text(text, Charset.defaultCharset());
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public T param(String k, Object v) {
|
||||
this.params.put(k, v);
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public T params(Map<String, Object> p) {
|
||||
this.params.putAll(p);
|
||||
return self();
|
||||
}
|
||||
|
||||
protected abstract T self();
|
||||
|
||||
protected String text() {
|
||||
return this.text;
|
||||
}
|
||||
|
||||
protected Map<String, Object> params() {
|
||||
return this.params;
|
||||
}
|
||||
PromptSystemSpec param(String k, Object v);
|
||||
|
||||
}
|
||||
|
||||
class UserSpec extends AbstractPromptSpec<UserSpec> implements PromptSpec<UserSpec> {
|
||||
interface AdvisorSpec {
|
||||
|
||||
private final List<Media> media = new ArrayList<>();
|
||||
AdvisorSpec param(String k, Object v);
|
||||
|
||||
public UserSpec media(Media... media) {
|
||||
this.media.addAll(Arrays.asList(media));
|
||||
return self();
|
||||
}
|
||||
AdvisorSpec params(Map<String, Object> p);
|
||||
|
||||
public UserSpec media(MimeType mimeType, URL url) {
|
||||
this.media.add(new Media(mimeType, url));
|
||||
return self();
|
||||
}
|
||||
AdvisorSpec advisors(RequestResponseAdvisor... advisors);
|
||||
|
||||
public UserSpec media(MimeType mimeType, Resource resource) {
|
||||
this.media.add(new Media(mimeType, resource));
|
||||
return self();
|
||||
}
|
||||
|
||||
protected List<Media> media() {
|
||||
return this.media;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected UserSpec self() {
|
||||
return this;
|
||||
}
|
||||
AdvisorSpec advisors(List<RequestResponseAdvisor> advisors);
|
||||
|
||||
}
|
||||
|
||||
class SystemSpec extends AbstractPromptSpec<SystemSpec> implements PromptSpec<SystemSpec> {
|
||||
interface CallResponseSpec {
|
||||
|
||||
@Override
|
||||
protected SystemSpec self() {
|
||||
return this;
|
||||
}
|
||||
<T> T entity(ParameterizedTypeReference<T> type);
|
||||
|
||||
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);
|
||||
|
||||
<T> T entity(Class<T> type);
|
||||
|
||||
ChatResponse chatResponse();
|
||||
|
||||
String content();
|
||||
|
||||
}
|
||||
|
||||
class ChatClientPromptRequest {
|
||||
interface StreamResponseSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
Flux<ChatResponse> chatResponse();
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
public ChatClientPromptRequest(ChatModel chatModel, Prompt prompt) {
|
||||
this.chatModel = chatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public ChatClientRequest.CallPromptResponseSpec call() {
|
||||
return new ChatClientRequest.CallPromptResponseSpec(this.chatModel, this.prompt);
|
||||
}
|
||||
|
||||
public ChatClientRequest.StreamPromptResponseSpec stream() {
|
||||
return new ChatClientRequest.StreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt);
|
||||
}
|
||||
Flux<String> content();
|
||||
|
||||
}
|
||||
|
||||
class AdvisorSpec {
|
||||
interface ChatClientPromptRequestSpec {
|
||||
|
||||
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
|
||||
CallPromptResponseSpec call();
|
||||
|
||||
private final Map<String, Object> params = new HashMap<>();
|
||||
|
||||
public AdvisorSpec param(String k, Object v) {
|
||||
this.params.put(k, v);
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec params(Map<String, Object> p) {
|
||||
this.params.putAll(p);
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec advisors(RequestResponseAdvisor... advisors) {
|
||||
this.advisors.addAll(List.of(advisors));
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec advisors(List<RequestResponseAdvisor> advisors) {
|
||||
this.advisors.addAll(advisors);
|
||||
return this;
|
||||
}
|
||||
StreamPromptResponseSpec stream();
|
||||
|
||||
}
|
||||
|
||||
class ChatClientRequest {
|
||||
interface CallPromptResponseSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
String content();
|
||||
|
||||
private String userText = "";
|
||||
List<String> contents();
|
||||
|
||||
private String systemText = "";
|
||||
ChatResponse chatResponse();
|
||||
|
||||
private ChatOptions chatOptions;
|
||||
}
|
||||
|
||||
private final List<Media> media = new ArrayList<>();
|
||||
interface StreamPromptResponseSpec {
|
||||
|
||||
private final List<String> functionNames = new ArrayList<>();
|
||||
Flux<ChatResponse> chatResponse();
|
||||
|
||||
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
public Flux<String> content();
|
||||
|
||||
private final List<Message> messages = new ArrayList<>();
|
||||
}
|
||||
|
||||
private final Map<String, Object> userParams = new HashMap<>();
|
||||
|
||||
private final Map<String, Object> systemParams = new HashMap<>();
|
||||
|
||||
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
|
||||
|
||||
private final Map<String, Object> advisorParams = new HashMap<>();
|
||||
|
||||
/* copy constructor */
|
||||
ChatClientRequest(ChatClientRequest ccr) {
|
||||
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
|
||||
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
|
||||
}
|
||||
|
||||
public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Object> userParams,
|
||||
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
|
||||
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
|
||||
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
|
||||
|
||||
this.chatModel = chatModel;
|
||||
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
|
||||
|
||||
this.userText = userText;
|
||||
this.userParams.putAll(userParams);
|
||||
this.systemText = systemText;
|
||||
this.systemParams.putAll(systemParams);
|
||||
|
||||
this.functionNames.addAll(functionNames);
|
||||
this.functionCallbacks.addAll(functionCallbacks);
|
||||
this.messages.addAll(messages);
|
||||
this.media.addAll(media);
|
||||
this.advisors.addAll(advisors);
|
||||
this.advisorParams.putAll(advisorParams);
|
||||
}
|
||||
interface ChatClientRequestSpec {
|
||||
|
||||
/**
|
||||
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
|
||||
* settings are replicated from this {@code ChatClientRequest}.
|
||||
*/
|
||||
public Builder mutate() {
|
||||
Builder builder = ChatClient.builder(chatModel)
|
||||
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
|
||||
.defaultUser(u -> u.text(this.userText)
|
||||
.params(this.userParams)
|
||||
.media(this.media.toArray(new Media[this.media.size()])))
|
||||
.defaultOptions(this.chatOptions)
|
||||
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
|
||||
Builder mutate();
|
||||
|
||||
// workaround to set the missing fields.
|
||||
builder.defaultRequest.messages.addAll(this.messages);
|
||||
builder.defaultRequest.functionCallbacks.addAll(this.functionCallbacks);
|
||||
ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer);
|
||||
|
||||
return builder;
|
||||
}
|
||||
ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors);
|
||||
|
||||
public ChatClientRequest advisors(Consumer<AdvisorSpec> consumer) {
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
var as = new AdvisorSpec();
|
||||
consumer.accept(as);
|
||||
this.advisorParams.putAll(as.params);
|
||||
this.advisors.addAll(as.advisors);
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec advisors(List<RequestResponseAdvisor> advisors);
|
||||
|
||||
public ChatClientRequest advisors(RequestResponseAdvisor... advisors) {
|
||||
Assert.notNull(advisors, "the advisors must be non-null");
|
||||
this.advisors.addAll(List.of(advisors));
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec messages(Message... messages);
|
||||
|
||||
public ChatClientRequest advisors(List<RequestResponseAdvisor> advisors) {
|
||||
Assert.notNull(advisors, "the advisors must be non-null");
|
||||
this.advisors.addAll(advisors);
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec messages(List<Message> messages);
|
||||
|
||||
public ChatClientRequest messages(Message... messages) {
|
||||
Assert.notNull(messages, "the messages must be non-null");
|
||||
this.messages.addAll(List.of(messages));
|
||||
return this;
|
||||
}
|
||||
<T extends ChatOptions> ChatClientRequestSpec options(T options);
|
||||
|
||||
public ChatClientRequest messages(List<Message> messages) {
|
||||
Assert.notNull(messages, "the messages must be non-null");
|
||||
this.messages.addAll(messages);
|
||||
return this;
|
||||
}
|
||||
<I, O> ChatClientRequestSpec function(String name, String description,
|
||||
java.util.function.Function<I, O> function);
|
||||
|
||||
public <T extends ChatOptions> ChatClientRequest options(T options) {
|
||||
Assert.notNull(options, "the options must be non-null");
|
||||
this.chatOptions = options;
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec functions(String... functionBeanNames);
|
||||
|
||||
public <I, O> ChatClientRequest function(String name, String description,
|
||||
java.util.function.Function<I, O> function) {
|
||||
ChatClientRequestSpec system(String text);
|
||||
|
||||
Assert.hasText(name, "the name must be non-null and non-empty");
|
||||
Assert.hasText(description, "the description must be non-null and non-empty");
|
||||
Assert.notNull(function, "the function must be non-null");
|
||||
ChatClientRequestSpec system(Resource textResource, Charset charset);
|
||||
|
||||
var fcw = FunctionCallbackWrapper.builder(function)
|
||||
.withDescription(description)
|
||||
.withName(name)
|
||||
.withResponseConverter(Object::toString)
|
||||
.build();
|
||||
this.functionCallbacks.add(fcw);
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec system(Resource text);
|
||||
|
||||
public ChatClientRequest functions(String... functionBeanNames) {
|
||||
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
|
||||
this.functionNames.addAll(List.of(functionBeanNames));
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer);
|
||||
|
||||
public ChatClientRequest system(String text) {
|
||||
Assert.notNull(text, "the text must be non-null");
|
||||
this.systemText = text;
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec user(String text);
|
||||
|
||||
public ChatClientRequest system(Resource textResource, Charset charset) {
|
||||
ChatClientRequestSpec user(Resource text, Charset charset);
|
||||
|
||||
Assert.notNull(textResource, "the text resource must be non-null");
|
||||
Assert.notNull(charset, "the charset must be non-null");
|
||||
ChatClientRequestSpec user(Resource text);
|
||||
|
||||
try {
|
||||
this.systemText = textResource.getContentAsString(charset);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
|
||||
|
||||
public ChatClientRequest system(Resource text) {
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
return this.system(text, Charset.defaultCharset());
|
||||
}
|
||||
// ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest,
|
||||
// Map<String, Object> context);
|
||||
|
||||
public ChatClientRequest system(Consumer<SystemSpec> consumer) {
|
||||
CallResponseSpec call();
|
||||
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
|
||||
var ss = new SystemSpec();
|
||||
consumer.accept(ss);
|
||||
this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText;
|
||||
this.systemParams.putAll(ss.params());
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequest user(String text) {
|
||||
Assert.notNull(text, "the text must be non-null");
|
||||
this.userText = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequest user(Resource text, Charset charset) {
|
||||
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
Assert.notNull(charset, "the charset must be non-null");
|
||||
|
||||
try {
|
||||
this.userText = text.getContentAsString(charset);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequest user(Resource text) {
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
return this.user(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public ChatClientRequest user(Consumer<UserSpec> consumer) {
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
|
||||
var us = new UserSpec();
|
||||
consumer.accept(us);
|
||||
this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
|
||||
this.userParams.putAll(us.params());
|
||||
this.media.addAll(us.media());
|
||||
return this;
|
||||
}
|
||||
|
||||
public static class StreamPromptResponseSpec {
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
private final StreamingChatModel chatModel;
|
||||
|
||||
public StreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) {
|
||||
this.chatModel = streamingChatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
return doGetFluxChatResponse(this.prompt);
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> doGetFluxChatResponse(Prompt prompt) {
|
||||
return this.chatModel.stream(prompt);
|
||||
}
|
||||
|
||||
public Flux<String> content() {
|
||||
return doGetFluxChatResponse(this.prompt).map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getContent() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getContent();
|
||||
}).filter(v -> StringUtils.hasText(v));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class CallPromptResponseSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
public CallPromptResponseSpec(ChatModel chatModel, Prompt prompt) {
|
||||
this.chatModel = chatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public String content() {
|
||||
return doGetChatResponse(this.prompt).getResult().getOutput().getContent();
|
||||
}
|
||||
|
||||
public List<String> contents() {
|
||||
return doGetChatResponse(this.prompt).getResults()
|
||||
.stream()
|
||||
.map(r -> r.getOutput().getContent())
|
||||
.toList();
|
||||
}
|
||||
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetChatResponse(this.prompt);
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse(Prompt prompt) {
|
||||
return chatModel.call(prompt);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest, Map<String, Object> context) {
|
||||
|
||||
ChatClientRequest advisedRequest = inputRequest;
|
||||
|
||||
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
|
||||
AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText,
|
||||
inputRequest.systemText, inputRequest.chatOptions, inputRequest.media,
|
||||
inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages,
|
||||
inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors,
|
||||
inputRequest.advisorParams);
|
||||
|
||||
// apply the advisors onRequest
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
adviseRequest = advisor.adviseRequest(adviseRequest, context);
|
||||
}
|
||||
|
||||
advisedRequest = new ChatClientRequest(adviseRequest.chatModel(), adviseRequest.userText(),
|
||||
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
|
||||
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
|
||||
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
|
||||
adviseRequest.advisorParams());
|
||||
}
|
||||
|
||||
return advisedRequest;
|
||||
}
|
||||
|
||||
public static class CallResponseSpec {
|
||||
|
||||
private final ChatClientRequest request;
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public CallResponseSpec(ChatModel chatModel, ChatClientRequest request) {
|
||||
this.chatModel = chatModel;
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public <T> T entity(ParameterizedTypeReference<T> type) {
|
||||
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
|
||||
}
|
||||
|
||||
public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
|
||||
return doSingleWithBeanOutputConverter(structuredOutputConverter);
|
||||
}
|
||||
|
||||
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
|
||||
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
|
||||
var stringResponse = chatResponse.getResult().getOutput().getContent();
|
||||
return boc.convert(stringResponse);
|
||||
}
|
||||
|
||||
public <T> T entity(Class<T> type) {
|
||||
Assert.notNull(type, "the class must be non-null");
|
||||
var boc = new BeanOutputConverter<T>(type);
|
||||
return doSingleWithBeanOutputConverter(boc);
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse() {
|
||||
return this.doGetChatResponse(this.request, "");
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse(ChatClientRequest inputRequest, String formatParam) {
|
||||
|
||||
Map<String, Object> context = new ConcurrentHashMap<>();
|
||||
context.putAll(inputRequest.advisorParams);
|
||||
ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context);
|
||||
|
||||
var processedUserText = StringUtils.hasText(formatParam)
|
||||
? advisedRequest.userText + System.lineSeparator() + "{spring_ai_soc_format}"
|
||||
: advisedRequest.userText;
|
||||
|
||||
Map<String, Object> userParams = new HashMap<>(advisedRequest.userParams);
|
||||
if (StringUtils.hasText(formatParam)) {
|
||||
userParams.put("spring_ai_soc_format", formatParam);
|
||||
}
|
||||
|
||||
var messages = new ArrayList<Message>(advisedRequest.messages);
|
||||
var textsAreValid = (StringUtils.hasText(processedUserText)
|
||||
|| StringUtils.hasText(advisedRequest.systemText));
|
||||
if (textsAreValid) {
|
||||
if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) {
|
||||
var systemMessage = new SystemMessage(
|
||||
new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render());
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
UserMessage userMessage = null;
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
|
||||
advisedRequest.media);
|
||||
}
|
||||
else {
|
||||
userMessage = new UserMessage(processedUserText, advisedRequest.media);
|
||||
}
|
||||
messages.add(userMessage);
|
||||
}
|
||||
|
||||
if (advisedRequest.chatOptions instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
if (!advisedRequest.functionNames.isEmpty()) {
|
||||
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames));
|
||||
}
|
||||
if (!advisedRequest.functionCallbacks.isEmpty()) {
|
||||
functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks);
|
||||
}
|
||||
}
|
||||
var prompt = new Prompt(messages, advisedRequest.chatOptions);
|
||||
var chatResponse = this.chatModel.call(prompt);
|
||||
|
||||
ChatResponse advisedResponse = chatResponse;
|
||||
// apply the advisors on response
|
||||
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
advisedResponse = advisor.adviseResponse(advisedResponse, context);
|
||||
}
|
||||
}
|
||||
|
||||
return advisedResponse;
|
||||
}
|
||||
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetChatResponse();
|
||||
}
|
||||
|
||||
public String content() {
|
||||
return doGetChatResponse().getResult().getOutput().getContent();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class StreamResponseSpec {
|
||||
|
||||
private final ChatClientRequest request;
|
||||
|
||||
private final StreamingChatModel chatModel;
|
||||
|
||||
public StreamResponseSpec(StreamingChatModel streamingChatModel, ChatClientRequest request) {
|
||||
this.chatModel = streamingChatModel;
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> doGetFluxChatResponse(ChatClientRequest inputRequest) {
|
||||
|
||||
Map<String, Object> context = new ConcurrentHashMap<>();
|
||||
context.putAll(inputRequest.advisorParams);
|
||||
ChatClientRequest advisedRequest = adviseOnRequest(inputRequest, context);
|
||||
|
||||
String processedUserText = advisedRequest.userText;
|
||||
Map<String, Object> userParams = new HashMap<>(advisedRequest.userParams);
|
||||
|
||||
var messages = new ArrayList<Message>(advisedRequest.messages);
|
||||
var textsAreValid = (StringUtils.hasText(processedUserText)
|
||||
|| StringUtils.hasText(advisedRequest.systemText));
|
||||
if (textsAreValid) {
|
||||
UserMessage userMessage = null;
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
|
||||
advisedRequest.media);
|
||||
}
|
||||
else {
|
||||
userMessage = new UserMessage(processedUserText, advisedRequest.media);
|
||||
}
|
||||
if (StringUtils.hasText(advisedRequest.systemText) || !advisedRequest.systemParams.isEmpty()) {
|
||||
var systemMessage = new SystemMessage(
|
||||
new PromptTemplate(advisedRequest.systemText, advisedRequest.systemParams).render());
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
messages.add(userMessage);
|
||||
}
|
||||
|
||||
if (advisedRequest.chatOptions instanceof
|
||||
|
||||
FunctionCallingOptions functionCallingOptions) {
|
||||
if (!advisedRequest.functionNames.isEmpty()) {
|
||||
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.functionNames));
|
||||
}
|
||||
if (!advisedRequest.functionCallbacks.isEmpty()) {
|
||||
functionCallingOptions.setFunctionCallbacks(advisedRequest.functionCallbacks);
|
||||
}
|
||||
}
|
||||
var prompt = new Prompt(messages, advisedRequest.chatOptions);
|
||||
|
||||
var fluxChatResponse = this.chatModel.stream(prompt);
|
||||
|
||||
Flux<ChatResponse> advisedResponse = fluxChatResponse;
|
||||
// apply the advisors on response
|
||||
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
advisedResponse = advisor.adviseResponse(advisedResponse, context);
|
||||
}
|
||||
}
|
||||
|
||||
return advisedResponse;
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
return doGetFluxChatResponse(this.request);
|
||||
}
|
||||
|
||||
public Flux<String> content() {
|
||||
return doGetFluxChatResponse(this.request).map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getContent() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getContent();
|
||||
}).filter(v -> StringUtils.hasText(v));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public CallResponseSpec call() {
|
||||
return new CallResponseSpec(this.chatModel, this);
|
||||
}
|
||||
|
||||
public StreamResponseSpec stream() {
|
||||
return new StreamResponseSpec((StreamingChatModel) this.chatModel, this);
|
||||
}
|
||||
StreamResponseSpec stream();
|
||||
|
||||
}
|
||||
|
||||
class Builder {
|
||||
/**
|
||||
* A mutable builder for creating a {@link ChatClient}.
|
||||
*/
|
||||
interface Builder {
|
||||
|
||||
private final ChatClientRequest defaultRequest;
|
||||
Builder defaultAdvisors(RequestResponseAdvisor... advisor);
|
||||
|
||||
private final ChatModel chatModel;
|
||||
Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);
|
||||
|
||||
Builder(ChatModel chatModel) {
|
||||
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
|
||||
this.chatModel = chatModel;
|
||||
this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
|
||||
List.of(), List.of(), null, List.of(), Map.of());
|
||||
}
|
||||
Builder defaultAdvisors(List<RequestResponseAdvisor> advisors);
|
||||
|
||||
public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
|
||||
this.defaultRequest.advisors(advisor);
|
||||
return this;
|
||||
}
|
||||
ChatClient build();
|
||||
|
||||
public Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer) {
|
||||
this.defaultRequest.advisors(advisorSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
Builder defaultOptions(ChatOptions chatOptions);
|
||||
|
||||
public Builder defaultAdvisors(List<RequestResponseAdvisor> advisors) {
|
||||
this.defaultRequest.advisors(advisors);
|
||||
return this;
|
||||
}
|
||||
Builder defaultUser(String text);
|
||||
|
||||
public ChatClient build() {
|
||||
return new DefaultChatClient(this.chatModel, this.defaultRequest);
|
||||
}
|
||||
Builder defaultUser(Resource text, Charset charset);
|
||||
|
||||
public Builder defaultOptions(ChatOptions chatOptions) {
|
||||
this.defaultRequest.options(chatOptions);
|
||||
return this;
|
||||
}
|
||||
Builder defaultUser(Resource text);
|
||||
|
||||
public Builder defaultUser(String text) {
|
||||
this.defaultRequest.user(text);
|
||||
return this;
|
||||
}
|
||||
Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);
|
||||
|
||||
public Builder defaultUser(Resource text, Charset charset) {
|
||||
try {
|
||||
this.defaultRequest.user(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
Builder defaultSystem(String text);
|
||||
|
||||
public Builder defaultUser(Resource text) {
|
||||
return this.defaultUser(text, Charset.defaultCharset());
|
||||
}
|
||||
Builder defaultSystem(Resource text, Charset charset);
|
||||
|
||||
public Builder defaultUser(Consumer<UserSpec> userSpecConsumer) {
|
||||
this.defaultRequest.user(userSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
Builder defaultSystem(Resource text);
|
||||
|
||||
public Builder defaultSystem(String text) {
|
||||
this.defaultRequest.system(text);
|
||||
return this;
|
||||
}
|
||||
Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);
|
||||
|
||||
public Builder defaultSystem(Resource text, Charset charset) {
|
||||
try {
|
||||
this.defaultRequest.system(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
<I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function);
|
||||
|
||||
public Builder defaultSystem(Resource text) {
|
||||
return this.defaultSystem(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public Builder defaultSystem(Consumer<SystemSpec> systemSpecConsumer) {
|
||||
this.defaultRequest.system(systemSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
|
||||
public <I, O> Builder defaultFunction(String name, String description,
|
||||
java.util.function.Function<I, O> function) {
|
||||
this.defaultRequest.function(name, description, function);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultFunctions(String... functionNames) {
|
||||
this.defaultRequest.functions(functionNames);
|
||||
return this;
|
||||
}
|
||||
Builder defaultFunctions(String... functionNames);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1,42 +1,74 @@
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.messages.Media;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.StructuredOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallbackWrapper;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.MimeType;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* The default implementation of {@link ChatClient} as created by the
|
||||
* {@link ChatClient.Builder#build()} } method.
|
||||
* {@link Builder#build()} } method.
|
||||
*
|
||||
* @author Mark Pollack
|
||||
* @author Christian Tzolov
|
||||
* @author Josh Long
|
||||
* @author Arjen Poutsma
|
||||
* @since 1.0.0 M1
|
||||
* @since 1.0.0
|
||||
*/
|
||||
class DefaultChatClient implements ChatClient {
|
||||
public class DefaultChatClient implements ChatClient {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private final ChatClientRequest defaultChatClientRequest;
|
||||
private final DefaultChatClientRequestSpec defaultChatClientRequest;
|
||||
|
||||
public DefaultChatClient(ChatModel chatModel, ChatClientRequest defaultChatClientRequest) {
|
||||
public DefaultChatClient(ChatModel chatModel, DefaultChatClientRequestSpec defaultChatClientRequest) {
|
||||
this.chatModel = chatModel;
|
||||
this.defaultChatClientRequest = defaultChatClientRequest;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientRequest prompt() {
|
||||
return new ChatClientRequest(this.defaultChatClientRequest);
|
||||
public ChatClientRequestSpec prompt() {
|
||||
return new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientPromptRequest prompt(Prompt prompt) {
|
||||
return new ChatClientPromptRequest(this.chatModel, prompt);
|
||||
public ChatClientPromptRequestSpec prompt(Prompt prompt) {
|
||||
return new DefaultChatClientPromptRequestSpec(this.chatModel, prompt);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
|
||||
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient} whose
|
||||
* settings are replicated from this {@code ChatClientRequest}.
|
||||
*/
|
||||
@Override
|
||||
@@ -44,6 +76,732 @@ class DefaultChatClient implements ChatClient {
|
||||
return this.defaultChatClientRequest.mutate();
|
||||
}
|
||||
|
||||
public static class DefaultPromptUserSpec implements PromptUserSpec {
|
||||
|
||||
private String text = "";
|
||||
|
||||
private final Map<String, Object> params = new HashMap<>();
|
||||
|
||||
private final List<Media> media = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public PromptUserSpec media(Media... media) {
|
||||
this.media.addAll(Arrays.asList(media));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec media(MimeType mimeType, URL url) {
|
||||
this.media.add(new Media(mimeType, url));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec media(MimeType mimeType, Resource resource) {
|
||||
this.media.add(new Media(mimeType, resource));
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec text(String text) {
|
||||
this.text = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec text(Resource text, Charset charset) {
|
||||
try {
|
||||
this.text(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec text(Resource text) {
|
||||
this.text(text, Charset.defaultCharset());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec param(String k, Object v) {
|
||||
this.params.put(k, v);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptUserSpec params(Map<String, Object> p) {
|
||||
this.params.putAll(p);
|
||||
return this;
|
||||
}
|
||||
|
||||
protected String text() {
|
||||
return this.text;
|
||||
}
|
||||
|
||||
protected Map<String, Object> params() {
|
||||
return this.params;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Media> media() {
|
||||
return this.media;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultPromptSystemSpec implements PromptSystemSpec {
|
||||
|
||||
private String text = "";
|
||||
|
||||
private final Map<String, Object> params = new HashMap<>();
|
||||
|
||||
@Override
|
||||
public PromptSystemSpec text(String text) {
|
||||
this.text = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptSystemSpec text(Resource text, Charset charset) {
|
||||
try {
|
||||
this.text(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptSystemSpec text(Resource text) {
|
||||
this.text(text, Charset.defaultCharset());
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptSystemSpec param(String k, Object v) {
|
||||
this.params.put(k, v);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PromptSystemSpec params(Map<String, Object> p) {
|
||||
this.params.putAll(p);
|
||||
return this;
|
||||
}
|
||||
|
||||
protected String text() {
|
||||
return this.text;
|
||||
}
|
||||
|
||||
protected Map<String, Object> params() {
|
||||
return this.params;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultAdvisorSpec implements AdvisorSpec {
|
||||
|
||||
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
|
||||
|
||||
private final Map<String, Object> params = new HashMap<>();
|
||||
|
||||
public AdvisorSpec param(String k, Object v) {
|
||||
this.params.put(k, v);
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec params(Map<String, Object> p) {
|
||||
this.params.putAll(p);
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec advisors(RequestResponseAdvisor... advisors) {
|
||||
this.advisors.addAll(List.of(advisors));
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorSpec advisors(List<RequestResponseAdvisor> advisors) {
|
||||
this.advisors.addAll(advisors);
|
||||
return this;
|
||||
}
|
||||
|
||||
public List<RequestResponseAdvisor> getAdvisors() {
|
||||
return advisors;
|
||||
}
|
||||
|
||||
public Map<String, Object> getParams() {
|
||||
return params;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultCallResponseSpec implements CallResponseSpec {
|
||||
|
||||
private final DefaultChatClientRequestSpec request;
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public DefaultCallResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) {
|
||||
this.chatModel = chatModel;
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public <T> T entity(ParameterizedTypeReference<T> type) {
|
||||
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
|
||||
}
|
||||
|
||||
public <T> T entity(StructuredOutputConverter<T> structuredOutputConverter) {
|
||||
return doSingleWithBeanOutputConverter(structuredOutputConverter);
|
||||
}
|
||||
|
||||
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> boc) {
|
||||
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
|
||||
var stringResponse = chatResponse.getResult().getOutput().getContent();
|
||||
return boc.convert(stringResponse);
|
||||
}
|
||||
|
||||
public <T> T entity(Class<T> type) {
|
||||
Assert.notNull(type, "the class must be non-null");
|
||||
var boc = new BeanOutputConverter<T>(type);
|
||||
return doSingleWithBeanOutputConverter(boc);
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse() {
|
||||
return this.doGetChatResponse(this.request, "");
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequest, String formatParam) {
|
||||
|
||||
Map<String, Object> context = new ConcurrentHashMap<>();
|
||||
context.putAll(inputRequest.getAdvisorParams());
|
||||
DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest,
|
||||
context);
|
||||
|
||||
var processedUserText = StringUtils.hasText(formatParam)
|
||||
? advisedRequest.getUserText() + System.lineSeparator() + "{spring_ai_soc_format}"
|
||||
: advisedRequest.getUserText();
|
||||
|
||||
Map<String, Object> userParams = new HashMap<>(advisedRequest.getUserParams());
|
||||
if (StringUtils.hasText(formatParam)) {
|
||||
userParams.put("spring_ai_soc_format", formatParam);
|
||||
}
|
||||
|
||||
var messages = new ArrayList<Message>(advisedRequest.getMessages());
|
||||
var textsAreValid = (StringUtils.hasText(processedUserText)
|
||||
|| StringUtils.hasText(advisedRequest.getSystemText()));
|
||||
if (textsAreValid) {
|
||||
if (StringUtils.hasText(advisedRequest.getSystemText())
|
||||
|| !advisedRequest.getSystemParams().isEmpty()) {
|
||||
var systemMessage = new SystemMessage(
|
||||
new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams())
|
||||
.render());
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
UserMessage userMessage = null;
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
|
||||
advisedRequest.getMedia());
|
||||
}
|
||||
else {
|
||||
userMessage = new UserMessage(processedUserText, advisedRequest.getMedia());
|
||||
}
|
||||
messages.add(userMessage);
|
||||
}
|
||||
|
||||
if (advisedRequest.getChatOptions() instanceof FunctionCallingOptions functionCallingOptions) {
|
||||
if (!advisedRequest.getFunctionNames().isEmpty()) {
|
||||
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames()));
|
||||
}
|
||||
if (!advisedRequest.getFunctionCallbacks().isEmpty()) {
|
||||
functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks());
|
||||
}
|
||||
}
|
||||
var prompt = new Prompt(messages, advisedRequest.getChatOptions());
|
||||
var chatResponse = this.chatModel.call(prompt);
|
||||
|
||||
ChatResponse advisedResponse = chatResponse;
|
||||
// apply the advisors on response
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) {
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors());
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
advisedResponse = advisor.adviseResponse(advisedResponse, context);
|
||||
}
|
||||
}
|
||||
|
||||
return advisedResponse;
|
||||
}
|
||||
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetChatResponse();
|
||||
}
|
||||
|
||||
public String content() {
|
||||
return doGetChatResponse().getResult().getOutput().getContent();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultStreamResponseSpec implements StreamResponseSpec {
|
||||
|
||||
private final DefaultChatClientRequestSpec request;
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public DefaultStreamResponseSpec(ChatModel chatModel, DefaultChatClientRequestSpec request) {
|
||||
this.chatModel = chatModel;
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> doGetFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
|
||||
|
||||
Map<String, Object> context = new ConcurrentHashMap<>();
|
||||
context.putAll(inputRequest.getAdvisorParams());
|
||||
DefaultChatClientRequestSpec advisedRequest = DefaultChatClientRequestSpec.adviseOnRequest(inputRequest,
|
||||
context);
|
||||
|
||||
String processedUserText = advisedRequest.getUserText();
|
||||
Map<String, Object> userParams = new HashMap<>(advisedRequest.getUserParams());
|
||||
|
||||
var messages = new ArrayList<Message>(advisedRequest.getMessages());
|
||||
var textsAreValid = (StringUtils.hasText(processedUserText)
|
||||
|| StringUtils.hasText(advisedRequest.getSystemText()));
|
||||
if (textsAreValid) {
|
||||
UserMessage userMessage = null;
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
userMessage = new UserMessage(new PromptTemplate(processedUserText, userParams).render(),
|
||||
advisedRequest.getMedia());
|
||||
}
|
||||
else {
|
||||
userMessage = new UserMessage(processedUserText, advisedRequest.getMedia());
|
||||
}
|
||||
if (StringUtils.hasText(advisedRequest.getSystemText())
|
||||
|| !advisedRequest.getSystemParams().isEmpty()) {
|
||||
var systemMessage = new SystemMessage(
|
||||
new PromptTemplate(advisedRequest.getSystemText(), advisedRequest.getSystemParams())
|
||||
.render());
|
||||
messages.add(systemMessage);
|
||||
}
|
||||
messages.add(userMessage);
|
||||
}
|
||||
|
||||
if (advisedRequest.getChatOptions() instanceof
|
||||
|
||||
FunctionCallingOptions functionCallingOptions) {
|
||||
if (!advisedRequest.getFunctionNames().isEmpty()) {
|
||||
functionCallingOptions.setFunctions(new HashSet<>(advisedRequest.getFunctionNames()));
|
||||
}
|
||||
if (!advisedRequest.getFunctionCallbacks().isEmpty()) {
|
||||
functionCallingOptions.setFunctionCallbacks(advisedRequest.getFunctionCallbacks());
|
||||
}
|
||||
}
|
||||
var prompt = new Prompt(messages, advisedRequest.getChatOptions());
|
||||
|
||||
var fluxChatResponse = this.chatModel.stream(prompt);
|
||||
|
||||
Flux<ChatResponse> advisedResponse = fluxChatResponse;
|
||||
// apply the advisors on response
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getAdvisors())) {
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.getAdvisors());
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
advisedResponse = advisor.adviseResponse(advisedResponse, context);
|
||||
}
|
||||
}
|
||||
|
||||
return advisedResponse;
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
return doGetFluxChatResponse(this.request);
|
||||
}
|
||||
|
||||
public Flux<String> content() {
|
||||
return doGetFluxChatResponse(this.request).map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getContent() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getContent();
|
||||
}).filter(v -> StringUtils.hasText(v));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private String userText = "";
|
||||
|
||||
private String systemText = "";
|
||||
|
||||
private ChatOptions chatOptions;
|
||||
|
||||
private final List<Media> media = new ArrayList<>();
|
||||
|
||||
private final List<String> functionNames = new ArrayList<>();
|
||||
|
||||
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
|
||||
private final List<Message> messages = new ArrayList<>();
|
||||
|
||||
private final Map<String, Object> userParams = new HashMap<>();
|
||||
|
||||
private final Map<String, Object> systemParams = new HashMap<>();
|
||||
|
||||
private List<RequestResponseAdvisor> advisors = new ArrayList<>();
|
||||
|
||||
private final Map<String, Object> advisorParams = new HashMap<>();
|
||||
|
||||
public String getUserText() {
|
||||
return userText;
|
||||
}
|
||||
|
||||
public Map<String, Object> getUserParams() {
|
||||
return userParams;
|
||||
}
|
||||
|
||||
public String getSystemText() {
|
||||
return systemText;
|
||||
}
|
||||
|
||||
public Map<String, Object> getSystemParams() {
|
||||
return systemParams;
|
||||
}
|
||||
|
||||
public ChatOptions getChatOptions() {
|
||||
return chatOptions;
|
||||
}
|
||||
|
||||
public List<RequestResponseAdvisor> getAdvisors() {
|
||||
return advisors;
|
||||
}
|
||||
|
||||
public Map<String, Object> getAdvisorParams() {
|
||||
return advisorParams;
|
||||
}
|
||||
|
||||
public List<Message> getMessages() {
|
||||
return messages;
|
||||
}
|
||||
|
||||
public List<Media> getMedia() {
|
||||
return media;
|
||||
}
|
||||
|
||||
public List<String> getFunctionNames() {
|
||||
return this.functionNames;
|
||||
}
|
||||
|
||||
public List<FunctionCallback> getFunctionCallbacks() {
|
||||
return functionCallbacks;
|
||||
}
|
||||
|
||||
/* copy constructor */
|
||||
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
|
||||
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
|
||||
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
|
||||
}
|
||||
|
||||
public DefaultChatClientRequestSpec(ChatModel chatModel, String userText, Map<String, Object> userParams,
|
||||
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
|
||||
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
|
||||
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
|
||||
|
||||
this.chatModel = chatModel;
|
||||
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
|
||||
|
||||
this.userText = userText;
|
||||
this.userParams.putAll(userParams);
|
||||
this.systemText = systemText;
|
||||
this.systemParams.putAll(systemParams);
|
||||
|
||||
this.functionNames.addAll(functionNames);
|
||||
this.functionCallbacks.addAll(functionCallbacks);
|
||||
this.messages.addAll(messages);
|
||||
this.media.addAll(media);
|
||||
this.advisors.addAll(advisors);
|
||||
this.advisorParams.putAll(advisorParams);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
|
||||
* settings are replicated from this {@code ChatClientRequest}.
|
||||
*/
|
||||
public Builder mutate() {
|
||||
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient.builder(chatModel)
|
||||
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
|
||||
.defaultUser(u -> u.text(this.userText)
|
||||
.params(this.userParams)
|
||||
.media(this.media.toArray(new Media[this.media.size()])))
|
||||
.defaultOptions(this.chatOptions)
|
||||
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
|
||||
|
||||
// workaround to set the missing fields.
|
||||
builder.defaultRequest.getMessages().addAll(this.messages);
|
||||
builder.defaultRequest.getFunctionCallbacks().addAll(this.functionCallbacks);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec advisors(Consumer<ChatClient.AdvisorSpec> consumer) {
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
var as = new DefaultAdvisorSpec();
|
||||
consumer.accept(as);
|
||||
this.advisorParams.putAll(as.getParams());
|
||||
this.advisors.addAll(as.getAdvisors());
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec advisors(RequestResponseAdvisor... advisors) {
|
||||
Assert.notNull(advisors, "the advisors must be non-null");
|
||||
this.advisors.addAll(List.of(advisors));
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec advisors(List<RequestResponseAdvisor> advisors) {
|
||||
Assert.notNull(advisors, "the advisors must be non-null");
|
||||
this.advisors.addAll(advisors);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec messages(Message... messages) {
|
||||
Assert.notNull(messages, "the messages must be non-null");
|
||||
this.messages.addAll(List.of(messages));
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec messages(List<Message> messages) {
|
||||
Assert.notNull(messages, "the messages must be non-null");
|
||||
this.messages.addAll(messages);
|
||||
return this;
|
||||
}
|
||||
|
||||
public <T extends ChatOptions> ChatClientRequestSpec options(T options) {
|
||||
Assert.notNull(options, "the options must be non-null");
|
||||
this.chatOptions = options;
|
||||
return this;
|
||||
}
|
||||
|
||||
public <I, O> ChatClientRequestSpec function(String name, String description,
|
||||
java.util.function.Function<I, O> function) {
|
||||
|
||||
Assert.hasText(name, "the name must be non-null and non-empty");
|
||||
Assert.hasText(description, "the description must be non-null and non-empty");
|
||||
Assert.notNull(function, "the function must be non-null");
|
||||
|
||||
var fcw = FunctionCallbackWrapper.builder(function)
|
||||
.withDescription(description)
|
||||
.withName(name)
|
||||
.withResponseConverter(Object::toString)
|
||||
.build();
|
||||
this.functionCallbacks.add(fcw);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec functions(String... functionBeanNames) {
|
||||
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
|
||||
this.functionNames.addAll(List.of(functionBeanNames));
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec system(String text) {
|
||||
Assert.notNull(text, "the text must be non-null");
|
||||
this.systemText = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec system(Resource textResource, Charset charset) {
|
||||
|
||||
Assert.notNull(textResource, "the text resource must be non-null");
|
||||
Assert.notNull(charset, "the charset must be non-null");
|
||||
|
||||
try {
|
||||
this.systemText = textResource.getContentAsString(charset);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec system(Resource text) {
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
return this.system(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer) {
|
||||
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
|
||||
var ss = new DefaultPromptSystemSpec();
|
||||
consumer.accept(ss);
|
||||
this.systemText = StringUtils.hasText(ss.text()) ? ss.text() : this.systemText;
|
||||
this.systemParams.putAll(ss.params());
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec user(String text) {
|
||||
Assert.notNull(text, "the text must be non-null");
|
||||
this.userText = text;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec user(Resource text, Charset charset) {
|
||||
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
Assert.notNull(charset, "the charset must be non-null");
|
||||
|
||||
try {
|
||||
this.userText = text.getContentAsString(charset);
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec user(Resource text) {
|
||||
Assert.notNull(text, "the text resource must be non-null");
|
||||
return this.user(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer) {
|
||||
Assert.notNull(consumer, "the consumer must be non-null");
|
||||
|
||||
var us = new DefaultPromptUserSpec();
|
||||
consumer.accept(us);
|
||||
this.userText = StringUtils.hasText(us.text()) ? us.text() : this.userText;
|
||||
this.userParams.putAll(us.params());
|
||||
this.media.addAll(us.media());
|
||||
return this;
|
||||
}
|
||||
|
||||
public CallResponseSpec call() {
|
||||
return new DefaultCallResponseSpec(chatModel, this);
|
||||
}
|
||||
|
||||
public StreamResponseSpec stream() {
|
||||
return new DefaultStreamResponseSpec(chatModel, this);
|
||||
}
|
||||
|
||||
public static DefaultChatClientRequestSpec adviseOnRequest(DefaultChatClientRequestSpec inputRequest,
|
||||
Map<String, Object> context) {
|
||||
|
||||
DefaultChatClientRequestSpec advisedRequest = inputRequest;
|
||||
|
||||
if (!CollectionUtils.isEmpty(inputRequest.advisors)) {
|
||||
AdvisedRequest adviseRequest = new AdvisedRequest(inputRequest.chatModel, inputRequest.userText,
|
||||
inputRequest.systemText, inputRequest.chatOptions, inputRequest.media,
|
||||
inputRequest.functionNames, inputRequest.functionCallbacks, inputRequest.messages,
|
||||
inputRequest.userParams, inputRequest.systemParams, inputRequest.advisors,
|
||||
inputRequest.advisorParams);
|
||||
|
||||
// apply the advisors onRequest
|
||||
var currentAdvisors = new ArrayList<>(inputRequest.advisors);
|
||||
for (RequestResponseAdvisor advisor : currentAdvisors) {
|
||||
adviseRequest = advisor.adviseRequest(adviseRequest, context);
|
||||
}
|
||||
|
||||
advisedRequest = new DefaultChatClientRequestSpec(adviseRequest.chatModel(), adviseRequest.userText(),
|
||||
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
|
||||
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
|
||||
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
|
||||
adviseRequest.advisorParams());
|
||||
}
|
||||
|
||||
return advisedRequest;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Prompt
|
||||
|
||||
public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
public DefaultCallPromptResponseSpec(ChatModel chatModel, Prompt prompt) {
|
||||
this.chatModel = chatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public String content() {
|
||||
return doGetChatResponse(this.prompt).getResult().getOutput().getContent();
|
||||
}
|
||||
|
||||
public List<String> contents() {
|
||||
return doGetChatResponse(this.prompt).getResults().stream().map(r -> r.getOutput().getContent()).toList();
|
||||
}
|
||||
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetChatResponse(this.prompt);
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse(Prompt prompt) {
|
||||
return chatModel.call(prompt);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec {
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
private final StreamingChatModel chatModel;
|
||||
|
||||
public DefaultStreamPromptResponseSpec(StreamingChatModel streamingChatModel, Prompt prompt) {
|
||||
this.chatModel = streamingChatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
return doGetFluxChatResponse(this.prompt);
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> doGetFluxChatResponse(Prompt prompt) {
|
||||
return this.chatModel.stream(prompt);
|
||||
}
|
||||
|
||||
public Flux<String> content() {
|
||||
return doGetFluxChatResponse(this.prompt).map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getContent() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getContent();
|
||||
}).filter(v -> StringUtils.hasText(v));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultChatClientPromptRequestSpec implements ChatClientPromptRequestSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
public DefaultChatClientPromptRequestSpec(ChatModel chatModel, Prompt prompt) {
|
||||
this.chatModel = chatModel;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public CallPromptResponseSpec call() {
|
||||
return new DefaultCallPromptResponseSpec(this.chatModel, this.prompt);
|
||||
}
|
||||
|
||||
public StreamPromptResponseSpec stream() {
|
||||
return new DefaultStreamPromptResponseSpec((StreamingChatModel) this.chatModel, this.prompt);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* use the new fluid DSL starting in {@link #prompt()}
|
||||
* @param prompt the {@link Prompt prompt} object
|
||||
@@ -55,4 +813,4 @@ class DefaultChatClient implements ChatClient {
|
||||
return this.chatModel.call(prompt);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient.Builder;
|
||||
import org.springframework.ai.chat.client.ChatClient.PromptSystemSpec;
|
||||
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
|
||||
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* @author Mark Pollack
|
||||
* @author Christian Tzolov
|
||||
* @author Josh Long
|
||||
* @author Arjen Poutsma
|
||||
* @since 1.0.0
|
||||
*
|
||||
*/
|
||||
public class DefaultChatClientBuilder implements Builder {
|
||||
|
||||
protected final DefaultChatClientRequestSpec defaultRequest;
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public DefaultChatClientBuilder(ChatModel chatModel) {
|
||||
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
|
||||
this.chatModel = chatModel;
|
||||
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, "", Map.of(), "", Map.of(), List.of(),
|
||||
List.of(), List.of(), List.of(), null, List.of(), Map.of());
|
||||
}
|
||||
|
||||
public ChatClient build() {
|
||||
return new DefaultChatClient(this.chatModel, this.defaultRequest);
|
||||
}
|
||||
|
||||
public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
|
||||
this.defaultRequest.advisors(advisor);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultAdvisors(Consumer<ChatClient.AdvisorSpec> advisorSpecConsumer) {
|
||||
this.defaultRequest.advisors(advisorSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultAdvisors(List<RequestResponseAdvisor> advisors) {
|
||||
this.defaultRequest.advisors(advisors);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultOptions(ChatOptions chatOptions) {
|
||||
this.defaultRequest.options(chatOptions);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultUser(String text) {
|
||||
this.defaultRequest.user(text);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultUser(Resource text, Charset charset) {
|
||||
try {
|
||||
this.defaultRequest.user(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultUser(Resource text) {
|
||||
return this.defaultUser(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer) {
|
||||
this.defaultRequest.user(userSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultSystem(String text) {
|
||||
this.defaultRequest.system(text);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultSystem(Resource text, Charset charset) {
|
||||
try {
|
||||
this.defaultRequest.system(text.getContentAsString(charset));
|
||||
}
|
||||
catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultSystem(Resource text) {
|
||||
return this.defaultSystem(text, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
public Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer) {
|
||||
this.defaultRequest.system(systemSpecConsumer);
|
||||
return this;
|
||||
}
|
||||
|
||||
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
|
||||
this.defaultRequest.function(name, description, function);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultFunctions(String... functionNames) {
|
||||
this.defaultRequest.functions(functionNames);
|
||||
return this;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -20,7 +20,6 @@ import java.util.Map;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient.ChatClientRequest;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -31,7 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||
* chain of advisors with chared execution context.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @since 1.0.0 M1
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface RequestResponseAdvisor {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user