Remove deprecations in ChatClient and Advisors
* Update remaining Advisors and related classes to use the new APIs. * In AbstractChatMemoryAdvisor, the “doNextWithProtectFromBlockingBefore()” protected method has been changed from accepting AdvisedRequest to ChatClientRequest. It’s a breaking change since the alternative was not part of M8. * MessageAggregator has a new method to aggregate messages from ChatClientRequest. The previous method aggregating messages from AdvisedRequest has been removed. Warning since it wasn’t marked as deprecated in M8. * In SimpleLoggerAdvisor, the “requestToString” input argument needs to be updated to use ChatClientRequest. It’s a breaking change since the alternative was not part of M8. Same thing about the constructor. * The “getTemplateRenderer” method has been removed from BaseAdvisorChain. Each Advisor is encouraged to accept a PromptTemplate to achieve self-contained prompt augmentation operations. * Remove deprecations in ChatClient and Advisors, and update tests accordingly. * When building a Prompt from the ChatClient input, the SystemMessage passed via systemText() is placed first in the message list. Before, it was put last, resulting in errors with several model providers. Relates to gh-2655 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
This commit is contained in:
committed by
Ilayaperumal Gopinathan
parent
88490b3dfc
commit
4fe74d886e
@@ -21,17 +21,13 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponseStreamUtils;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
@@ -53,7 +49,7 @@ import org.springframework.util.StringUtils;
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public class QuestionAnswerAdvisor implements BaseAdvisor {
|
||||
|
||||
public static final String RETRIEVED_DOCUMENTS = "qa_retrieved_documents";
|
||||
|
||||
@@ -80,100 +76,24 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
|
||||
|
||||
private final SearchRequest searchRequest;
|
||||
|
||||
private final boolean protectFromBlocking;
|
||||
private final Scheduler scheduler;
|
||||
|
||||
private final int order;
|
||||
|
||||
/**
|
||||
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
|
||||
* combines it with the user's text.
|
||||
* @param vectorStore The vector store to use
|
||||
*/
|
||||
public QuestionAnswerAdvisor(VectorStore vectorStore) {
|
||||
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
|
||||
}
|
||||
|
||||
/**
|
||||
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
|
||||
* combines it with the user's text.
|
||||
* @param vectorStore The vector store to use
|
||||
* @param searchRequest The search request defined using the portable filter
|
||||
* expression syntax
|
||||
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
|
||||
*/
|
||||
@Deprecated
|
||||
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest) {
|
||||
this(vectorStore, searchRequest, DEFAULT_PROMPT_TEMPLATE, true, DEFAULT_ORDER);
|
||||
}
|
||||
|
||||
/**
|
||||
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
|
||||
* combines it with the user's text.
|
||||
* @param vectorStore The vector store to use
|
||||
* @param searchRequest The search request defined using the portable filter
|
||||
* expression syntax
|
||||
* @param userTextAdvise The user text to append to the existing user prompt. The text
|
||||
* should contain a placeholder named "question_answer_context".
|
||||
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
|
||||
*/
|
||||
@Deprecated
|
||||
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise) {
|
||||
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), true,
|
||||
this(vectorStore, SearchRequest.builder().build(), DEFAULT_PROMPT_TEMPLATE, BaseAdvisor.DEFAULT_SCHEDULER,
|
||||
DEFAULT_ORDER);
|
||||
}
|
||||
|
||||
/**
|
||||
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
|
||||
* combines it with the user's text.
|
||||
* @param vectorStore The vector store to use
|
||||
* @param searchRequest The search request defined using the portable filter
|
||||
* expression syntax
|
||||
* @param userTextAdvise The user text to append to the existing user prompt. The text
|
||||
* should contain a placeholder named "question_answer_context".
|
||||
* @param protectFromBlocking If true the advisor will protect the execution from
|
||||
* blocking threads. If false the advisor will not protect the execution from blocking
|
||||
* threads. This is useful when the advisor is used in a non-blocking environment. It
|
||||
* is true by default.
|
||||
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
|
||||
*/
|
||||
@Deprecated
|
||||
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
|
||||
boolean protectFromBlocking) {
|
||||
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
|
||||
DEFAULT_ORDER);
|
||||
}
|
||||
|
||||
/**
|
||||
* The QuestionAnswerAdvisor retrieves context information from a Vector Store and
|
||||
* combines it with the user's text.
|
||||
* @param vectorStore The vector store to use
|
||||
* @param searchRequest The search request defined using the portable filter
|
||||
* expression syntax
|
||||
* @param userTextAdvise The user text to append to the existing user prompt. The text
|
||||
* should contain a placeholder named "question_answer_context".
|
||||
* @param protectFromBlocking If true the advisor will protect the execution from
|
||||
* blocking threads. If false the advisor will not protect the execution from blocking
|
||||
* threads. This is useful when the advisor is used in a non-blocking environment. It
|
||||
* is true by default.
|
||||
* @param order The order of the advisor.
|
||||
* @deprecated in favor of the builder: {@link #builder(VectorStore)}
|
||||
*/
|
||||
@Deprecated
|
||||
public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise,
|
||||
boolean protectFromBlocking, int order) {
|
||||
this(vectorStore, searchRequest, PromptTemplate.builder().template(userTextAdvise).build(), protectFromBlocking,
|
||||
order);
|
||||
}
|
||||
|
||||
QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, @Nullable PromptTemplate promptTemplate,
|
||||
boolean protectFromBlocking, int order) {
|
||||
@Nullable Scheduler scheduler, int order) {
|
||||
Assert.notNull(vectorStore, "vectorStore cannot be null");
|
||||
Assert.notNull(searchRequest, "searchRequest cannot be null");
|
||||
|
||||
this.vectorStore = vectorStore;
|
||||
this.searchRequest = searchRequest;
|
||||
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
|
||||
this.protectFromBlocking = protectFromBlocking;
|
||||
this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
|
||||
this.order = order;
|
||||
}
|
||||
|
||||
@@ -181,97 +101,71 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
|
||||
return new Builder(vectorStore);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return this.order;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
|
||||
AdvisedRequest advisedRequest2 = before(advisedRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest2);
|
||||
|
||||
return after(advisedResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
|
||||
// This can be executed by both blocking and non-blocking Threads
|
||||
// E.g. a command line or Tomcat blocking Thread implementation
|
||||
// or by a WebFlux dispatch in a non-blocking manner.
|
||||
Flux<AdvisedResponse> advisedResponses = (this.protectFromBlocking) ?
|
||||
// @formatter:off
|
||||
Mono.just(advisedRequest)
|
||||
.publishOn(Schedulers.boundedElastic())
|
||||
.map(this::before)
|
||||
.flatMapMany(request -> chain.nextAroundStream(request))
|
||||
: chain.nextAroundStream(before(advisedRequest));
|
||||
// @formatter:on
|
||||
|
||||
return advisedResponses.map(ar -> {
|
||||
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
|
||||
ar = after(ar);
|
||||
}
|
||||
return ar;
|
||||
});
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
|
||||
var context = new HashMap<>(request.adviseContext());
|
||||
|
||||
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
|
||||
// 1. Search for similar documents in the vector store.
|
||||
var searchRequestToUse = SearchRequest.from(this.searchRequest)
|
||||
.query(request.userText())
|
||||
.filterExpression(doGetFilterExpression(context))
|
||||
.query(chatClientRequest.prompt().getUserMessage().getText())
|
||||
.filterExpression(doGetFilterExpression(chatClientRequest.context()))
|
||||
.build();
|
||||
|
||||
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);
|
||||
|
||||
// 2. Create the context from the documents.
|
||||
Map<String, Object> context = new HashMap<>(chatClientRequest.context());
|
||||
context.put(RETRIEVED_DOCUMENTS, documents);
|
||||
|
||||
String documentContext = documents.stream()
|
||||
.map(Document::getText)
|
||||
.collect(Collectors.joining(System.lineSeparator()));
|
||||
String documentContext = documents == null ? ""
|
||||
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
|
||||
|
||||
// 3. Augment the user prompt with the document context.
|
||||
String augmentedUserText = this.promptTemplate.mutate()
|
||||
.template(request.userText() + System.lineSeparator() + this.promptTemplate.getTemplate())
|
||||
.template(chatClientRequest.prompt().getUserMessage().getText() + System.lineSeparator()
|
||||
+ this.promptTemplate.getTemplate())
|
||||
.variables(Map.of("question_answer_context", documentContext))
|
||||
.build()
|
||||
.render();
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.from(request)
|
||||
.userText(augmentedUserText)
|
||||
.adviseContext(context)
|
||||
// 4. Update ChatClientRequest with augmented prompt.
|
||||
return chatClientRequest.mutate()
|
||||
.prompt(chatClientRequest.prompt().augmentUserMessage(augmentedUserText))
|
||||
.context(context)
|
||||
.build();
|
||||
|
||||
return advisedRequest;
|
||||
}
|
||||
|
||||
private AdvisedResponse after(AdvisedResponse advisedResponse) {
|
||||
ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
|
||||
chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, advisedResponse.adviseContext().get(RETRIEVED_DOCUMENTS));
|
||||
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
|
||||
@Override
|
||||
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
|
||||
ChatResponse.Builder chatResponseBuilder;
|
||||
if (chatClientResponse.chatResponse() == null) {
|
||||
chatResponseBuilder = ChatResponse.builder();
|
||||
}
|
||||
else {
|
||||
chatResponseBuilder = ChatResponse.builder().from(chatClientResponse.chatResponse());
|
||||
}
|
||||
chatResponseBuilder.metadata(RETRIEVED_DOCUMENTS, chatClientResponse.context().get(RETRIEVED_DOCUMENTS));
|
||||
return ChatClientResponse.builder()
|
||||
.chatResponse(chatResponseBuilder.build())
|
||||
.context(chatClientResponse.context())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Nullable
|
||||
protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {
|
||||
|
||||
if (!context.containsKey(FILTER_EXPRESSION)
|
||||
|| !StringUtils.hasText(context.get(FILTER_EXPRESSION).toString())) {
|
||||
return this.searchRequest.getFilterExpression();
|
||||
}
|
||||
return new FilterExpressionTextParser().parse(context.get(FILTER_EXPRESSION).toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Scheduler getScheduler() {
|
||||
return this.scheduler;
|
||||
}
|
||||
|
||||
public static final class Builder {
|
||||
@@ -282,7 +176,7 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
|
||||
|
||||
private PromptTemplate promptTemplate;
|
||||
|
||||
private boolean protectFromBlocking = true;
|
||||
private Scheduler scheduler;
|
||||
|
||||
private int order = DEFAULT_ORDER;
|
||||
|
||||
@@ -303,18 +197,13 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favour of {@link #promptTemplate(PromptTemplate)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder userTextAdvise(String userTextAdvise) {
|
||||
Assert.hasText(userTextAdvise, "The userTextAdvise must not be empty!");
|
||||
this.promptTemplate = PromptTemplate.builder().template(userTextAdvise).build();
|
||||
public Builder protectFromBlocking(boolean protectFromBlocking) {
|
||||
this.scheduler = protectFromBlocking ? BaseAdvisor.DEFAULT_SCHEDULER : Schedulers.immediate();
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder protectFromBlocking(boolean protectFromBlocking) {
|
||||
this.protectFromBlocking = protectFromBlocking;
|
||||
public Builder scheduler(Scheduler scheduler) {
|
||||
this.scheduler = scheduler;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -324,8 +213,8 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv
|
||||
}
|
||||
|
||||
public QuestionAnswerAdvisor build() {
|
||||
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate,
|
||||
this.protectFromBlocking, this.order);
|
||||
return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.promptTemplate, this.scheduler,
|
||||
this.order);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.vectorstore;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -24,19 +25,20 @@ import java.util.stream.Collectors;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.MessageAggregator;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Memory is retrieved from a VectorStore added into the prompt's system text.
|
||||
@@ -87,80 +89,76 @@ public class VectorStoreChatMemoryAdvisor extends AbstractChatMemoryAdvisor<Vect
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
chatClientRequest = this.before(chatClientRequest);
|
||||
|
||||
advisedRequest = this.before(advisedRequest);
|
||||
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
|
||||
this.after(chatClientResponse);
|
||||
|
||||
this.observeAfter(advisedResponse);
|
||||
|
||||
return advisedResponse;
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
|
||||
streamAdvisorChain, this::before);
|
||||
|
||||
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
|
||||
this::before);
|
||||
|
||||
// The observeAfter will certainly be executed on non-blocking Threads in case
|
||||
// of some models - e.g. when the model client is a WebClient
|
||||
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
|
||||
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
|
||||
String advisedSystemText;
|
||||
if (StringUtils.hasText(request.systemText())) {
|
||||
advisedSystemText = request.systemText() + System.lineSeparator() + this.systemTextAdvise;
|
||||
}
|
||||
else {
|
||||
advisedSystemText = this.systemTextAdvise;
|
||||
}
|
||||
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
|
||||
String conversationId = this.doGetConversationId(chatClientRequest.context());
|
||||
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
|
||||
|
||||
// 1. Retrieve the chat memory for the current conversation.
|
||||
var searchRequest = SearchRequest.builder()
|
||||
.query(request.userText())
|
||||
.topK(this.doGetChatMemoryRetrieveSize(request.adviseContext()))
|
||||
.filterExpression(
|
||||
DOCUMENT_METADATA_CONVERSATION_ID + "=='" + this.doGetConversationId(request.adviseContext()) + "'")
|
||||
.query(chatClientRequest.prompt().getUserMessage().getText())
|
||||
.topK(chatMemoryRetrieveSize)
|
||||
.filterExpression(DOCUMENT_METADATA_CONVERSATION_ID + "=='" + conversationId + "'")
|
||||
.build();
|
||||
|
||||
List<Document> documents = this.getChatMemoryStore().similaritySearch(searchRequest);
|
||||
|
||||
String longTermMemory = documents.stream()
|
||||
.map(Document::getText)
|
||||
.collect(Collectors.joining(System.lineSeparator()));
|
||||
// 2. Processed memory messages as a string.
|
||||
String longTermMemory = documents == null ? ""
|
||||
: documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
|
||||
|
||||
Map<String, Object> advisedSystemParams = new HashMap<>(request.systemParams());
|
||||
advisedSystemParams.put("long_term_memory", longTermMemory);
|
||||
// 2. Augment the system message.
|
||||
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
|
||||
String augmentedSystemText = PromptTemplate.builder()
|
||||
.template(systemMessage.getText() + System.lineSeparator() + this.systemTextAdvise)
|
||||
.variables(Map.of("long_term_memory", longTermMemory))
|
||||
.build()
|
||||
.render();
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.from(request)
|
||||
.systemText(advisedSystemText)
|
||||
.systemParams(advisedSystemParams)
|
||||
// 3. Create a new request with the augmented system message.
|
||||
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
|
||||
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build();
|
||||
this.getChatMemoryStore()
|
||||
.write(toDocuments(List.of(userMessage), this.doGetConversationId(request.adviseContext())));
|
||||
// 4. Add the new user message to the conversation memory.
|
||||
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
|
||||
this.getChatMemoryStore().write(toDocuments(List.of(userMessage), conversationId));
|
||||
|
||||
return advisedRequest;
|
||||
return processedChatClientRequest;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
|
||||
List<Message> assistantMessages = advisedResponse.response()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
|
||||
private void after(ChatClientResponse chatClientResponse) {
|
||||
List<Message> assistantMessages = new ArrayList<>();
|
||||
if (chatClientResponse.chatResponse() != null) {
|
||||
assistantMessages = chatClientResponse.chatResponse()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
}
|
||||
this.getChatMemoryStore()
|
||||
.write(toDocuments(assistantMessages, this.doGetConversationId(advisedResponse.adviseContext())));
|
||||
.write(toDocuments(assistantMessages, this.doGetConversationId(chatClientResponse.context())));
|
||||
}
|
||||
|
||||
private List<Document> toDocuments(List<Message> messages, String conversationId) {
|
||||
|
||||
List<Document> docs = messages.stream()
|
||||
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
|
||||
.map(message -> {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -52,6 +52,7 @@ import static org.mockito.BDDMockito.given;
|
||||
* @author Christian Tzolov
|
||||
* @author Timo Salm
|
||||
* @author Alexandros Pappas
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class QuestionAnswerAdvisorTests {
|
||||
@@ -112,8 +113,9 @@ public class QuestionAnswerAdvisorTests {
|
||||
given(this.vectorStore.similaritySearch(this.vectorSearchCaptor.capture()))
|
||||
.willReturn(List.of(new Document("doc1"), new Document("doc2")));
|
||||
|
||||
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore,
|
||||
SearchRequest.builder().similarityThreshold(0.99d).topK(6).build());
|
||||
var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore)
|
||||
.searchRequest(SearchRequest.builder().similarityThreshold(0.99d).topK(6).build())
|
||||
.build();
|
||||
|
||||
var chatClient = ChatClient.builder(this.chatModel)
|
||||
.defaultSystem("Default system text.")
|
||||
@@ -187,7 +189,9 @@ public class QuestionAnswerAdvisorTests {
|
||||
.willReturn(List.of(new Document("doc1"), new Document("doc2")));
|
||||
|
||||
var chatClient = ChatClient.builder(this.chatModel).build();
|
||||
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build());
|
||||
var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore)
|
||||
.searchRequest(SearchRequest.builder().build())
|
||||
.build();
|
||||
|
||||
var userTextTemplate = "Please answer my question {question}";
|
||||
// @formatter:off
|
||||
@@ -215,10 +219,15 @@ public class QuestionAnswerAdvisorTests {
|
||||
.willReturn(List.of(new Document("doc1"), new Document("doc2")));
|
||||
|
||||
var chatClient = ChatClient.builder(this.chatModel).build();
|
||||
var qaAdvisor = new QuestionAnswerAdvisor(this.vectorStore, SearchRequest.builder().build());
|
||||
var qaAdvisor = QuestionAnswerAdvisor.builder(this.vectorStore)
|
||||
.searchRequest(SearchRequest.builder().build())
|
||||
.build();
|
||||
|
||||
var userTextTemplate = "Please answer my question {question}";
|
||||
var userPromptTemplate = PromptTemplate.builder().template(userTextTemplate).variables(Map.of("question", "XYZ")).build();
|
||||
var userPromptTemplate = PromptTemplate.builder()
|
||||
.template(userTextTemplate)
|
||||
.variables(Map.of("question", "XYZ"))
|
||||
.build();
|
||||
var userMessage = userPromptTemplate.createMessage();
|
||||
// @formatter:off
|
||||
chatClient.prompt(new Prompt(userMessage))
|
||||
|
||||
@@ -22,7 +22,6 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClientCustomizer;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationFilter;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
@@ -80,20 +79,6 @@ public class ChatClientAutoConfiguration {
|
||||
return chatClientBuilderConfigurer.configure(builder);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favour of {@link #chatClientPromptContentObservationFilter()}.
|
||||
*/
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnProperty(prefix = ChatClientBuilderProperties.CONFIG_PREFIX + ".observations", name = "include-input",
|
||||
havingValue = "true")
|
||||
@Deprecated
|
||||
ChatClientInputContentObservationFilter chatClientInputContentObservationFilter() {
|
||||
logger.warn(
|
||||
"You have enabled the inclusion of the input content in the observations, with the risk of exposing sensitive or private information. Please, be careful!");
|
||||
return new ChatClientInputContentObservationFilter();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnMissingBean
|
||||
@ConditionalOnProperty(prefix = ChatClientBuilderProperties.CONFIG_PREFIX + ".observations",
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
package org.springframework.ai.model.chat.client.autoconfigure;
|
||||
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.DeprecatedConfigurationProperty;
|
||||
|
||||
/**
|
||||
* Configuration properties for the chat client builder.
|
||||
@@ -55,27 +54,11 @@ public class ChatClientBuilderProperties {
|
||||
|
||||
public static class Observations {
|
||||
|
||||
/**
|
||||
* Whether to include the input content in the observations.
|
||||
* @deprecated Use {@link #includePrompt} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
private boolean includeInput = false;
|
||||
|
||||
/**
|
||||
* Whether to include the prompt content in the observations.
|
||||
*/
|
||||
private boolean includePrompt = false;
|
||||
|
||||
@DeprecatedConfigurationProperty(replacement = "spring.ai.chat.observations.include-prompt")
|
||||
public boolean isIncludeInput() {
|
||||
return this.includeInput;
|
||||
}
|
||||
|
||||
public void setIncludeInput(boolean includeCompletion) {
|
||||
this.includeInput = includeCompletion;
|
||||
}
|
||||
|
||||
public boolean isIncludePrompt() {
|
||||
return this.includePrompt;
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ package org.springframework.ai.model.chat.client.autoconfigure;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.observation.ChatClientInputContentObservationFilter;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientPromptContentObservationFilter;
|
||||
import org.springframework.boot.autoconfigure.AutoConfigurations;
|
||||
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
|
||||
@@ -36,18 +35,6 @@ class ChatClientObservationAutoConfigurationTests {
|
||||
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
|
||||
.withConfiguration(AutoConfigurations.of(ChatClientAutoConfiguration.class));
|
||||
|
||||
@Test
|
||||
void inputContentFilterDefault() {
|
||||
this.contextRunner
|
||||
.run(context -> assertThat(context).doesNotHaveBean(ChatClientInputContentObservationFilter.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void inputContentFilterEnabled() {
|
||||
this.contextRunner.withPropertyValues("spring.ai.chat.client.observations.include-input=true")
|
||||
.run(context -> assertThat(context).hasSingleBean(ChatClientInputContentObservationFilter.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
void promptContentFilterDefault() {
|
||||
this.contextRunner
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -28,13 +28,10 @@ import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
@@ -64,6 +61,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*")
|
||||
@@ -81,7 +79,7 @@ public class OpenAiPaymentTransactionIT {
|
||||
@ValueSource(strings = { "paymentStatus", "paymentStatuses" })
|
||||
public void transactionPaymentStatuses(String functionName) {
|
||||
List<TransactionStatusResponse> content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames(functionName)
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -112,7 +110,7 @@ public class OpenAiPaymentTransactionIT {
|
||||
});
|
||||
|
||||
Flux<String> flux = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames(functionName)
|
||||
.user(u -> u.text("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -141,49 +139,6 @@ public class OpenAiPaymentTransactionIT {
|
||||
|
||||
}
|
||||
|
||||
private static class LoggingAdvisor implements CallAroundAdvisor {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
|
||||
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
|
||||
advisedRequest = this.before(advisedRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
|
||||
|
||||
this.observeAfter(advisedResponse);
|
||||
|
||||
return advisedResponse;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.info("System text: \n" + request.systemText());
|
||||
logger.info("System params: " + request.systemParams());
|
||||
logger.info("User text: \n" + request.userText());
|
||||
logger.info("User params:" + request.userParams());
|
||||
logger.info("Function names: " + request.toolNames());
|
||||
|
||||
logger.info("Options: " + request.chatOptions().toString());
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.info("Response: " + advisedResponse.response());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record Transaction(String id) {
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,17 +16,13 @@
|
||||
|
||||
package org.springframework.ai.openai.chat.client;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
|
||||
/**
|
||||
* Drawing inspiration from the human strategy of re-reading, this advisor implements a
|
||||
@@ -36,9 +32,10 @@ import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
* Language Models</a>
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public class ReReadingAdvisor implements BaseAdvisor {
|
||||
|
||||
private static final String DEFAULT_RE2_ADVISE_TEMPLATE = """
|
||||
{re2_input_query}
|
||||
@@ -57,29 +54,22 @@ public class ReReadingAdvisor implements CallAroundAdvisor, StreamAroundAdvisor
|
||||
this.re2AdviseTemplate = re2AdviseTemplate;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
@Override
|
||||
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
|
||||
String augmentedUserText = PromptTemplate.builder()
|
||||
.template(re2AdviseTemplate)
|
||||
.variables(Map.of("re2_input_query", chatClientRequest.prompt().getUserMessage().getText()))
|
||||
.build()
|
||||
.render();
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest advisedRequest) {
|
||||
|
||||
Map<String, Object> advisedUserParams = new HashMap<>(advisedRequest.userParams());
|
||||
advisedUserParams.put("re2_input_query", advisedRequest.userText());
|
||||
|
||||
return AdvisedRequest.from(advisedRequest)
|
||||
.userText(this.re2AdviseTemplate)
|
||||
.userParams(advisedUserParams)
|
||||
return chatClientRequest.mutate()
|
||||
.prompt(chatClientRequest.prompt().augmentUserMessage(augmentedUserText))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
return chain.nextAroundCall(this.before(advisedRequest));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
return chain.nextAroundStream(this.before(advisedRequest));
|
||||
public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -29,13 +29,10 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
|
||||
@@ -57,6 +54,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
||||
@@ -75,7 +73,7 @@ public class VertexAiGeminiPaymentTransactionIT {
|
||||
public void paymentStatuses() {
|
||||
// @formatter:off
|
||||
String content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -92,7 +90,7 @@ public class VertexAiGeminiPaymentTransactionIT {
|
||||
public void streamingPaymentStatuses() {
|
||||
|
||||
Flux<String> streamContent = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -120,45 +118,6 @@ public class VertexAiGeminiPaymentTransactionIT {
|
||||
|
||||
}
|
||||
|
||||
private static class LoggingAdvisor implements CallAroundAdvisor {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
var response = chain.nextAroundCall(before(advisedRequest));
|
||||
observeAfter(response);
|
||||
return response;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.info("System text: \n" + request.systemText());
|
||||
logger.info("System params: " + request.systemParams());
|
||||
logger.info("User text: \n" + request.userText());
|
||||
logger.info("User params:" + request.userParams());
|
||||
logger.info("Function names: " + request.toolNames());
|
||||
|
||||
logger.info("Options: " + request.chatOptions().toString());
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.info("Response: " + advisedResponse.response());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record Transaction(String id) {
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -29,13 +29,10 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
@@ -59,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
||||
@@ -76,10 +74,15 @@ public class VertexAiGeminiPaymentTransactionMethodIT {
|
||||
@Test
|
||||
public void paymentStatuses() {
|
||||
|
||||
String content = this.chatClient.prompt().advisors(new LoggingAdvisor()).toolNames("paymentStatus").user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
""").call().content();
|
||||
String content = this.chatClient.prompt()
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
If requred invoke the function per transaction.
|
||||
""")
|
||||
.call()
|
||||
.content();
|
||||
logger.info("" + content);
|
||||
|
||||
assertThat(content).contains("001", "002", "003");
|
||||
@@ -90,7 +93,7 @@ public class VertexAiGeminiPaymentTransactionMethodIT {
|
||||
public void streamingPaymentStatuses() {
|
||||
|
||||
Flux<String> streamContent = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.toolNames("paymentStatus")
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -118,45 +121,6 @@ public class VertexAiGeminiPaymentTransactionMethodIT {
|
||||
|
||||
}
|
||||
|
||||
private static class LoggingAdvisor implements CallAroundAdvisor {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
var response = chain.nextAroundCall(before(advisedRequest));
|
||||
observeAfter(response);
|
||||
return response;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.info("System text: \n" + request.systemText());
|
||||
logger.info("System params: " + request.systemParams());
|
||||
logger.info("User text: \n" + request.userText());
|
||||
logger.info("User params:" + request.userParams());
|
||||
logger.info("Function names: " + request.toolNames());
|
||||
|
||||
logger.info("Options: " + request.chatOptions().toString());
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.info("Response: " + advisedResponse.response());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record Transaction(String id) {
|
||||
}
|
||||
|
||||
|
||||
@@ -28,13 +28,10 @@ import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.model.tool.ToolCallingManager;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.annotation.Tool;
|
||||
@@ -56,6 +53,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@SpringBootTest
|
||||
@EnabledIfEnvironmentVariable(named = "VERTEX_AI_GEMINI_PROJECT_ID", matches = ".*")
|
||||
@@ -74,7 +72,7 @@ public class VertexAiGeminiPaymentTransactionToolsIT {
|
||||
public void paymentStatuses() {
|
||||
// @formatter:off
|
||||
String content = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.tools(new MyTools())
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -91,7 +89,7 @@ public class VertexAiGeminiPaymentTransactionToolsIT {
|
||||
public void streamingPaymentStatuses() {
|
||||
|
||||
Flux<String> streamContent = this.chatClient.prompt()
|
||||
.advisors(new LoggingAdvisor())
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.tools(new MyTools())
|
||||
.user("""
|
||||
What is the status of my payment transactions 001, 002 and 003?
|
||||
@@ -119,45 +117,6 @@ public class VertexAiGeminiPaymentTransactionToolsIT {
|
||||
|
||||
}
|
||||
|
||||
private static class LoggingAdvisor implements CallAroundAdvisor {
|
||||
|
||||
private final Logger logger = LoggerFactory.getLogger(LoggingAdvisor.class);
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
var response = chain.nextAroundCall(before(advisedRequest));
|
||||
observeAfter(response);
|
||||
return response;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.info("System text: \n" + request.systemText());
|
||||
logger.info("System params: " + request.systemParams());
|
||||
logger.info("User text: \n" + request.userText());
|
||||
logger.info("User params:" + request.userParams());
|
||||
logger.info("Function names: " + request.toolNames());
|
||||
|
||||
logger.info("Options: " + request.chatOptions().toString());
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.info("Response: " + advisedResponse.response());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
record Transaction(String id) {
|
||||
}
|
||||
|
||||
|
||||
@@ -26,15 +26,7 @@ public enum ChatClientAttributes {
|
||||
|
||||
//@formatter:off
|
||||
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
ADVISORS("spring.ai.chat.client.advisors"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
CHAT_MODEL("spring.ai.chat.client.model"),
|
||||
OUTPUT_FORMAT("spring.ai.chat.client.output.format"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
USER_PARAMS("spring.ai.chat.client.user.params"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
SYSTEM_PARAMS("spring.ai.chat.client.system.params");
|
||||
OUTPUT_FORMAT("spring.ai.chat.client.output.format");
|
||||
|
||||
//@formatter:on
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ import java.net.URL;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@@ -38,7 +38,6 @@ import reactor.core.publisher.Flux;
|
||||
import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
@@ -47,16 +46,18 @@ import org.springframework.ai.chat.client.observation.ChatClientObservationDocum
|
||||
import org.springframework.ai.chat.client.observation.DefaultChatClientObservationConvention;
|
||||
import org.springframework.ai.chat.messages.AbstractMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.StructuredOutputConverter;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
@@ -64,7 +65,6 @@ import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.ToolCallbacks;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -97,52 +97,6 @@ public class DefaultChatClient implements ChatClient {
|
||||
this.defaultChatClientRequest = defaultChatClientRequest;
|
||||
}
|
||||
|
||||
private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) {
|
||||
Assert.notNull(inputRequest, "inputRequest cannot be null");
|
||||
|
||||
Map<String, Object> advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams());
|
||||
|
||||
// Process userText, media and messages before creating the AdvisedRequest.
|
||||
String userText = inputRequest.userText;
|
||||
List<Media> media = inputRequest.media;
|
||||
List<Message> messages = inputRequest.messages;
|
||||
|
||||
// If the userText is empty, then try extracting the userText from the last
|
||||
// message
|
||||
// in the messages list and remove it from the messages list.
|
||||
if (!StringUtils.hasText(userText) && !CollectionUtils.isEmpty(messages)) {
|
||||
Message lastMessage = messages.get(messages.size() - 1);
|
||||
if (lastMessage.getMessageType() == MessageType.USER) {
|
||||
UserMessage userMessage = (UserMessage) lastMessage;
|
||||
if (StringUtils.hasText(userMessage.getText())) {
|
||||
userText = lastMessage.getText();
|
||||
}
|
||||
Collection<Media> messageMedia = userMessage.getMedia();
|
||||
if (!CollectionUtils.isEmpty(messageMedia)) {
|
||||
media.addAll(messageMedia);
|
||||
}
|
||||
messages = messages.subList(0, messages.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions,
|
||||
media, inputRequest.toolNames, inputRequest.toolCallbacks, messages, inputRequest.userParams,
|
||||
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
|
||||
inputRequest.toolContext);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
|
||||
ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) {
|
||||
|
||||
return new DefaultChatClientRequestSpec(advisedRequest.chatModel(), advisedRequest.userText(),
|
||||
advisedRequest.userParams(), advisedRequest.systemText(), advisedRequest.systemParams(),
|
||||
advisedRequest.toolCallbacks(), advisedRequest.messages(), advisedRequest.toolNames(),
|
||||
advisedRequest.media(), advisedRequest.chatOptions(), advisedRequest.advisors(),
|
||||
advisedRequest.advisorParams(), observationRegistry, customObservationConvention,
|
||||
advisedRequest.toolContext(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientRequestSpec prompt() {
|
||||
return new DefaultChatClientRequestSpec(this.defaultChatClientRequest);
|
||||
@@ -510,7 +464,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
.request(chatClientRequest)
|
||||
.advisors(advisorChain.getCallAdvisors())
|
||||
.stream(false)
|
||||
.withFormat(outputFormat)
|
||||
.format(outputFormat)
|
||||
.build();
|
||||
|
||||
var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(observationConvention,
|
||||
@@ -696,14 +650,6 @@ public class DefaultChatClient implements ChatClient {
|
||||
this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER;
|
||||
}
|
||||
|
||||
private ObservationRegistry getObservationRegistry() {
|
||||
return this.observationRegistry;
|
||||
}
|
||||
|
||||
private ChatClientObservationConvention getCustomObservationConvention() {
|
||||
return this.observationConvention;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public String getUserText() {
|
||||
return this.userText;
|
||||
@@ -755,6 +701,10 @@ public class DefaultChatClient implements ChatClient {
|
||||
return this.toolContext;
|
||||
}
|
||||
|
||||
public TemplateRenderer getTemplateRenderer() {
|
||||
return this.templateRenderer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
|
||||
* settings are replicated from this {@code ChatClientRequest}.
|
||||
@@ -762,6 +712,9 @@ public class DefaultChatClient implements ChatClient {
|
||||
public Builder mutate() {
|
||||
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
|
||||
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
|
||||
.defaultTemplateRenderer(this.templateRenderer)
|
||||
.defaultToolCallbacks(this.toolCallbacks)
|
||||
.defaultToolContext(this.toolContext)
|
||||
.defaultToolNames(StringUtils.toStringArray(this.toolNames));
|
||||
|
||||
if (StringUtils.hasText(this.userText)) {
|
||||
@@ -778,8 +731,6 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
builder.addMessages(this.messages);
|
||||
builder.addToolCallbacks(this.toolCallbacks);
|
||||
builder.addToolContext(this.toolContext);
|
||||
|
||||
return builder;
|
||||
}
|
||||
@@ -955,14 +906,14 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
public CallResponseSpec call() {
|
||||
BaseAdvisorChain advisorChain = buildAdvisorChain();
|
||||
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(this.templateRenderer),
|
||||
advisorChain, observationRegistry, observationConvention);
|
||||
return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
|
||||
observationRegistry, observationConvention);
|
||||
}
|
||||
|
||||
public StreamResponseSpec stream() {
|
||||
BaseAdvisorChain advisorChain = buildAdvisorChain();
|
||||
return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(this.templateRenderer),
|
||||
advisorChain, observationRegistry, observationConvention);
|
||||
return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
|
||||
observationRegistry, observationConvention);
|
||||
}
|
||||
|
||||
private BaseAdvisorChain buildAdvisorChain() {
|
||||
|
||||
@@ -36,7 +36,6 @@ import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
@@ -180,13 +179,6 @@ public class DefaultChatClientBuilder implements Builder {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated // Use defaultTools()
|
||||
public <I, O> Builder defaultFunction(String name, String description, java.util.function.Function<I, O> function) {
|
||||
this.defaultRequest
|
||||
.toolCallbacks(FunctionToolCallback.builder(name, function).description(description).build());
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder defaultToolContext(Map<String, Object> toolContext) {
|
||||
this.defaultRequest.toolContext(toolContext);
|
||||
return this;
|
||||
@@ -202,13 +194,4 @@ public class DefaultChatClientBuilder implements Builder {
|
||||
this.defaultRequest.messages(messages);
|
||||
}
|
||||
|
||||
void addToolCallbacks(List<ToolCallback> toolCallbacks) {
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
this.defaultRequest.toolCallbacks(toolCallbacks.toArray(ToolCallback[]::new));
|
||||
}
|
||||
|
||||
void addToolContext(Map<String, Object> toolContext) {
|
||||
this.defaultRequest.toolContext(toolContext);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* Utilities for supporting the {@link DefaultChatClient} implementation.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
class DefaultChatClientUtils {
|
||||
|
||||
static ChatClientRequest toChatClientRequest(DefaultChatClient.DefaultChatClientRequestSpec inputRequest) {
|
||||
Assert.notNull(inputRequest, "inputRequest cannot be null");
|
||||
|
||||
/*
|
||||
* ==========* MESSAGES * ==========
|
||||
*/
|
||||
|
||||
List<Message> processedMessages = new ArrayList<>();
|
||||
|
||||
// System Text => First in the list
|
||||
String processedSystemText = inputRequest.getSystemText();
|
||||
if (StringUtils.hasText(processedSystemText)) {
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getSystemParams())) {
|
||||
processedSystemText = PromptTemplate.builder()
|
||||
.template(processedSystemText)
|
||||
.variables(inputRequest.getSystemParams())
|
||||
.renderer(inputRequest.getTemplateRenderer())
|
||||
.build()
|
||||
.render();
|
||||
}
|
||||
processedMessages.add(new SystemMessage(processedSystemText));
|
||||
}
|
||||
|
||||
// Messages => In the middle of the list
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getMessages())) {
|
||||
processedMessages.addAll(inputRequest.getMessages());
|
||||
}
|
||||
|
||||
// User Test => Last in the list
|
||||
String processedUserText = inputRequest.getUserText();
|
||||
if (StringUtils.hasText(processedUserText)) {
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getUserParams())) {
|
||||
processedUserText = PromptTemplate.builder()
|
||||
.template(processedUserText)
|
||||
.variables(inputRequest.getUserParams())
|
||||
.renderer(inputRequest.getTemplateRenderer())
|
||||
.build()
|
||||
.render();
|
||||
}
|
||||
processedMessages.add(UserMessage.builder().text(processedUserText).media(inputRequest.getMedia()).build());
|
||||
}
|
||||
|
||||
/*
|
||||
* ==========* OPTIONS * ==========
|
||||
*/
|
||||
|
||||
ChatOptions processedChatOptions = inputRequest.getChatOptions();
|
||||
if (processedChatOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
|
||||
if (!inputRequest.getToolNames().isEmpty()) {
|
||||
Set<String> toolNames = ToolCallingChatOptions
|
||||
.mergeToolNames(new HashSet<>(inputRequest.getToolNames()), toolCallingChatOptions.getToolNames());
|
||||
toolCallingChatOptions.setToolNames(toolNames);
|
||||
}
|
||||
if (!inputRequest.getToolCallbacks().isEmpty()) {
|
||||
List<ToolCallback> toolCallbacks = ToolCallingChatOptions
|
||||
.mergeToolCallbacks(inputRequest.getToolCallbacks(), toolCallingChatOptions.getToolCallbacks());
|
||||
ToolCallingChatOptions.validateToolCallbacks(toolCallbacks);
|
||||
toolCallingChatOptions.setToolCallbacks(toolCallbacks);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(inputRequest.getToolContext())) {
|
||||
Map<String, Object> toolContext = ToolCallingChatOptions.mergeToolContext(inputRequest.getToolContext(),
|
||||
toolCallingChatOptions.getToolContext());
|
||||
toolCallingChatOptions.setToolContext(toolContext);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* ==========* REQUEST * ==========
|
||||
*/
|
||||
|
||||
return ChatClientRequest.builder()
|
||||
.prompt(Prompt.builder().messages(processedMessages).chatOptions(processedChatOptions).build())
|
||||
.context(new ConcurrentHashMap<>(inputRequest.getAdvisorParams()))
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -19,17 +19,17 @@ package org.springframework.ai.chat.client.advisor;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
@@ -38,9 +38,10 @@ import org.springframework.util.Assert;
|
||||
* @param <T> the type of the chat memory.
|
||||
* @author Christian Tzolov
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public abstract class AbstractChatMemoryAdvisor<T> implements CallAdvisor, StreamAdvisor {
|
||||
|
||||
/**
|
||||
* The key to retrieve the chat memory conversation id from the context.
|
||||
@@ -176,26 +177,18 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
|
||||
: this.defaultChatMemoryRetrieveSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the next advisor in the chain.
|
||||
* @param advisedRequest the advised request
|
||||
* @param chain the advisor chain
|
||||
* @param beforeAdvise the before advise function
|
||||
* @return the advised response
|
||||
*/
|
||||
protected Flux<AdvisedResponse> doNextWithProtectFromBlockingBefore(AdvisedRequest advisedRequest,
|
||||
StreamAroundAdvisorChain chain, Function<AdvisedRequest, AdvisedRequest> beforeAdvise) {
|
||||
|
||||
protected Flux<ChatClientResponse> doNextWithProtectFromBlockingBefore(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain, Function<ChatClientRequest, ChatClientRequest> before) {
|
||||
// This can be executed by both blocking and non-blocking Threads
|
||||
// E.g. a command line or Tomcat blocking Thread implementation
|
||||
// or by a WebFlux dispatch in a non-blocking manner.
|
||||
return (this.protectFromBlocking) ?
|
||||
// @formatter:off
|
||||
Mono.just(advisedRequest)
|
||||
.publishOn(Schedulers.boundedElastic())
|
||||
.map(beforeAdvise)
|
||||
.flatMapMany(request -> chain.nextAroundStream(request))
|
||||
: chain.nextAroundStream(beforeAdvise.apply(advisedRequest));
|
||||
Mono.just(chatClientRequest)
|
||||
.publishOn(Schedulers.boundedElastic())
|
||||
.map(before)
|
||||
.flatMapMany(streamAdvisorChain::nextStream)
|
||||
: streamAdvisorChain.nextStream(before.apply(chatClientRequest));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -242,7 +235,7 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
|
||||
* @param conversationId the conversation id
|
||||
* @return the builder
|
||||
*/
|
||||
public AbstractBuilder conversationId(String conversationId) {
|
||||
public AbstractBuilder<T> conversationId(String conversationId) {
|
||||
this.conversationId = conversationId;
|
||||
return this;
|
||||
}
|
||||
@@ -252,7 +245,7 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
|
||||
* @param chatMemoryRetrieveSize the chat memory retrieve size
|
||||
* @return the builder
|
||||
*/
|
||||
public AbstractBuilder chatMemoryRetrieveSize(int chatMemoryRetrieveSize) {
|
||||
public AbstractBuilder<T> chatMemoryRetrieveSize(int chatMemoryRetrieveSize) {
|
||||
this.chatMemoryRetrieveSize = chatMemoryRetrieveSize;
|
||||
return this;
|
||||
}
|
||||
@@ -262,7 +255,7 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
|
||||
* @param protectFromBlocking whether to protect from blocking
|
||||
* @return the builder
|
||||
*/
|
||||
public AbstractBuilder protectFromBlocking(boolean protectFromBlocking) {
|
||||
public AbstractBuilder<T> protectFromBlocking(boolean protectFromBlocking) {
|
||||
this.protectFromBlocking = protectFromBlocking;
|
||||
return this;
|
||||
}
|
||||
@@ -272,7 +265,7 @@ public abstract class AbstractChatMemoryAdvisor<T> implements CallAroundAdvisor,
|
||||
* @param order the order
|
||||
* @return the builder
|
||||
*/
|
||||
public AbstractBuilder order(int order) {
|
||||
public AbstractBuilder<T> order(int order) {
|
||||
this.order = order;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
@@ -46,7 +46,7 @@ public final class ChatModelCallAdvisor implements CallAdvisor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
ChatClientRequest formattedChatClientRequest = augmentWithFormatInstructions(chatClientRequest);
|
||||
|
||||
@@ -43,7 +43,8 @@ public final class ChatModelStreamAdvisor implements StreamAdvisor {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) {
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
return chatModel.stream(chatClientRequest.prompt())
|
||||
|
||||
@@ -26,13 +26,9 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccess
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.lang.Nullable;
|
||||
@@ -62,51 +58,46 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
|
||||
private static final TemplateRenderer DEFAULT_TEMPLATE_RENDERER = StTemplateRenderer.builder().build();
|
||||
|
||||
private final List<CallAroundAdvisor> originalCallAdvisors;
|
||||
private final List<CallAdvisor> originalCallAdvisors;
|
||||
|
||||
private final List<StreamAroundAdvisor> originalStreamAdvisors;
|
||||
private final List<StreamAdvisor> originalStreamAdvisors;
|
||||
|
||||
private final Deque<CallAroundAdvisor> callAroundAdvisors;
|
||||
private final Deque<CallAdvisor> callAdvisors;
|
||||
|
||||
private final Deque<StreamAroundAdvisor> streamAroundAdvisors;
|
||||
private final Deque<StreamAdvisor> streamAdvisors;
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
private final TemplateRenderer templateRenderer;
|
||||
|
||||
DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, @Nullable TemplateRenderer templateRenderer,
|
||||
Deque<CallAroundAdvisor> callAroundAdvisors, Deque<StreamAroundAdvisor> streamAroundAdvisors) {
|
||||
Deque<CallAdvisor> callAdvisors, Deque<StreamAdvisor> streamAdvisors) {
|
||||
|
||||
Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
|
||||
Assert.notNull(callAroundAdvisors, "the callAroundAdvisors must be non-null");
|
||||
Assert.notNull(streamAroundAdvisors, "the streamAroundAdvisors must be non-null");
|
||||
Assert.notNull(callAdvisors, "the callAdvisors must be non-null");
|
||||
Assert.notNull(streamAdvisors, "the streamAdvisors must be non-null");
|
||||
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER;
|
||||
this.callAroundAdvisors = callAroundAdvisors;
|
||||
this.streamAroundAdvisors = streamAroundAdvisors;
|
||||
this.originalCallAdvisors = List.copyOf(callAroundAdvisors);
|
||||
this.originalStreamAdvisors = List.copyOf(streamAroundAdvisors);
|
||||
this.callAdvisors = callAdvisors;
|
||||
this.streamAdvisors = streamAdvisors;
|
||||
this.originalCallAdvisors = List.copyOf(callAdvisors);
|
||||
this.originalStreamAdvisors = List.copyOf(streamAdvisors);
|
||||
}
|
||||
|
||||
public static Builder builder(ObservationRegistry observationRegistry) {
|
||||
return new Builder(observationRegistry);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TemplateRenderer getTemplateRenderer() {
|
||||
return this.templateRenderer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
if (this.callAroundAdvisors.isEmpty()) {
|
||||
if (this.callAdvisors.isEmpty()) {
|
||||
throw new IllegalStateException("No CallAdvisors available to execute");
|
||||
}
|
||||
|
||||
var advisor = this.callAroundAdvisors.pop();
|
||||
var advisor = this.callAdvisors.pop();
|
||||
|
||||
var observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
@@ -116,52 +107,7 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
|
||||
return AdvisorObservationDocumentation.AI_ADVISOR
|
||||
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
|
||||
.observe(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof CallAdvisor callAdvisor) {
|
||||
return callAdvisor.adviseCall(chatClientRequest, this);
|
||||
}
|
||||
AdvisedResponse advisedResponse = advisor.aroundCall(AdvisedRequest.from(chatClientRequest), this);
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
observationContext.setChatClientResponse(chatClientResponse);
|
||||
return chatClientResponse;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #nextCall(ChatClientRequest)} instead
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
|
||||
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
|
||||
|
||||
if (this.callAroundAdvisors.isEmpty()) {
|
||||
throw new IllegalStateException("No AroundAdvisor available to execute");
|
||||
}
|
||||
|
||||
var advisor = this.callAroundAdvisors.pop();
|
||||
|
||||
var observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.chatClientRequest(advisedRequest.toChatClientRequest(templateRenderer))
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
return AdvisorObservationDocumentation.AI_ADVISOR
|
||||
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
|
||||
.observe(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof CallAdvisor callAdvisor) {
|
||||
ChatClientResponse chatClientResponse = callAdvisor
|
||||
.adviseCall(advisedRequest.toChatClientRequest(templateRenderer), this);
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
AdvisedResponse advisedResponse = advisor.aroundCall(advisedRequest, this);
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
observationContext.setChatClientResponse(chatClientResponse);
|
||||
return advisedResponse;
|
||||
});
|
||||
.observe(() -> advisor.adviseCall(chatClientRequest, this));
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -169,11 +115,11 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
return Flux.deferContextual(contextView -> {
|
||||
if (this.streamAroundAdvisors.isEmpty()) {
|
||||
if (this.streamAdvisors.isEmpty()) {
|
||||
return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
|
||||
}
|
||||
|
||||
var advisor = this.streamAroundAdvisors.pop();
|
||||
var advisor = this.streamAdvisors.pop();
|
||||
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
@@ -187,77 +133,21 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
|
||||
|
||||
// @formatter:off
|
||||
return Flux.defer(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof StreamAdvisor streamAdvisor) {
|
||||
return streamAdvisor.adviseStream(chatClientRequest, this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
}
|
||||
return advisor.aroundStream(AdvisedRequest.from(chatClientRequest), this)
|
||||
return Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
|
||||
.map(AdvisedResponse::toChatClientResponse);
|
||||
});
|
||||
// @formatter:on
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #nextStream(ChatClientRequest)} instead.
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
|
||||
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
|
||||
|
||||
return Flux.deferContextual(contextView -> {
|
||||
if (this.streamAroundAdvisors.isEmpty()) {
|
||||
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
|
||||
}
|
||||
|
||||
var advisor = this.streamAroundAdvisors.pop();
|
||||
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.chatClientRequest(advisedRequest.toChatClientRequest(templateRenderer))
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
|
||||
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
|
||||
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
|
||||
|
||||
// @formatter:off
|
||||
return Flux.defer(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof StreamAdvisor streamAdvisor) {
|
||||
return streamAdvisor.adviseStream(advisedRequest.toChatClientRequest(templateRenderer), this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
|
||||
.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
return advisor.aroundStream(advisedRequest, this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
});
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
|
||||
// @formatter:on
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<CallAroundAdvisor> getCallAdvisors() {
|
||||
public List<CallAdvisor> getCallAdvisors() {
|
||||
return this.originalCallAdvisors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<StreamAroundAdvisor> getStreamAdvisors() {
|
||||
public List<StreamAdvisor> getStreamAdvisors() {
|
||||
return this.originalStreamAdvisors;
|
||||
}
|
||||
|
||||
@@ -270,16 +160,16 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
private final Deque<CallAroundAdvisor> callAroundAdvisors;
|
||||
private final Deque<CallAdvisor> callAdvisors;
|
||||
|
||||
private final Deque<StreamAroundAdvisor> streamAroundAdvisors;
|
||||
private final Deque<StreamAdvisor> streamAdvisors;
|
||||
|
||||
private TemplateRenderer templateRenderer;
|
||||
|
||||
public Builder(ObservationRegistry observationRegistry) {
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.callAroundAdvisors = new ConcurrentLinkedDeque<>();
|
||||
this.streamAroundAdvisors = new ConcurrentLinkedDeque<>();
|
||||
this.callAdvisors = new ConcurrentLinkedDeque<>();
|
||||
this.streamAdvisors = new ConcurrentLinkedDeque<>();
|
||||
}
|
||||
|
||||
public Builder templateRenderer(TemplateRenderer templateRenderer) {
|
||||
@@ -296,22 +186,22 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
Assert.notNull(advisors, "the advisors must be non-null");
|
||||
Assert.noNullElements(advisors, "the advisors must not contain null elements");
|
||||
if (!CollectionUtils.isEmpty(advisors)) {
|
||||
List<CallAroundAdvisor> callAroundAdvisorList = advisors.stream()
|
||||
.filter(a -> a instanceof CallAroundAdvisor)
|
||||
.map(a -> (CallAroundAdvisor) a)
|
||||
List<CallAdvisor> callAroundAdvisorList = advisors.stream()
|
||||
.filter(a -> a instanceof CallAdvisor)
|
||||
.map(a -> (CallAdvisor) a)
|
||||
.toList();
|
||||
|
||||
if (!CollectionUtils.isEmpty(callAroundAdvisorList)) {
|
||||
callAroundAdvisorList.forEach(this.callAroundAdvisors::push);
|
||||
callAroundAdvisorList.forEach(this.callAdvisors::push);
|
||||
}
|
||||
|
||||
List<StreamAroundAdvisor> streamAroundAdvisorList = advisors.stream()
|
||||
.filter(a -> a instanceof StreamAroundAdvisor)
|
||||
.map(a -> (StreamAroundAdvisor) a)
|
||||
List<StreamAdvisor> streamAroundAdvisorList = advisors.stream()
|
||||
.filter(a -> a instanceof StreamAdvisor)
|
||||
.map(a -> (StreamAdvisor) a)
|
||||
.toList();
|
||||
|
||||
if (!CollectionUtils.isEmpty(streamAroundAdvisorList)) {
|
||||
streamAroundAdvisorList.forEach(this.streamAroundAdvisors::push);
|
||||
streamAroundAdvisorList.forEach(this.streamAdvisors::push);
|
||||
}
|
||||
|
||||
this.reOrder();
|
||||
@@ -323,20 +213,20 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
* (Re)orders the advisors in priority order based on their Ordered attribute.
|
||||
*/
|
||||
private void reOrder() {
|
||||
ArrayList<CallAroundAdvisor> callAdvisors = new ArrayList<>(this.callAroundAdvisors);
|
||||
ArrayList<CallAdvisor> callAdvisors = new ArrayList<>(this.callAdvisors);
|
||||
OrderComparator.sort(callAdvisors);
|
||||
this.callAroundAdvisors.clear();
|
||||
callAdvisors.forEach(this.callAroundAdvisors::addLast);
|
||||
this.callAdvisors.clear();
|
||||
callAdvisors.forEach(this.callAdvisors::addLast);
|
||||
|
||||
ArrayList<StreamAroundAdvisor> streamAdvisors = new ArrayList<>(this.streamAroundAdvisors);
|
||||
ArrayList<StreamAdvisor> streamAdvisors = new ArrayList<>(this.streamAdvisors);
|
||||
OrderComparator.sort(streamAdvisors);
|
||||
this.streamAroundAdvisors.clear();
|
||||
streamAdvisors.forEach(this.streamAroundAdvisors::addLast);
|
||||
this.streamAdvisors.clear();
|
||||
streamAdvisors.forEach(this.streamAdvisors::addLast);
|
||||
}
|
||||
|
||||
public DefaultAroundAdvisorChain build() {
|
||||
return new DefaultAroundAdvisorChain(this.observationRegistry, this.templateRenderer,
|
||||
this.callAroundAdvisors, this.streamAroundAdvisors);
|
||||
return new DefaultAroundAdvisorChain(this.observationRegistry, this.templateRenderer, this.callAdvisors,
|
||||
this.streamAdvisors);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -21,11 +21,11 @@ import java.util.List;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
@@ -57,58 +57,59 @@ public class MessageChatMemoryAdvisor extends AbstractChatMemoryAdvisor<ChatMemo
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
chatClientRequest = this.before(chatClientRequest);
|
||||
|
||||
advisedRequest = this.before(advisedRequest);
|
||||
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
|
||||
this.after(chatClientResponse);
|
||||
|
||||
this.observeAfter(advisedResponse);
|
||||
|
||||
return advisedResponse;
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
|
||||
streamAdvisorChain, this::before);
|
||||
|
||||
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
|
||||
this::before);
|
||||
|
||||
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
|
||||
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
|
||||
String conversationId = this.doGetConversationId(chatClientRequest.context());
|
||||
|
||||
String conversationId = this.doGetConversationId(request.adviseContext());
|
||||
|
||||
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(request.adviseContext());
|
||||
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
|
||||
|
||||
// 1. Retrieve the chat memory for the current conversation.
|
||||
List<Message> memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize);
|
||||
|
||||
// 2. Advise the request messages list.
|
||||
List<Message> advisedMessages = new ArrayList<>(request.messages());
|
||||
advisedMessages.addAll(memoryMessages);
|
||||
List<Message> processedMessages = new ArrayList<>(memoryMessages);
|
||||
processedMessages.addAll(chatClientRequest.prompt().getInstructions());
|
||||
|
||||
// 3. Create a new request with the advised messages.
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.from(request).messages(advisedMessages).build();
|
||||
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
|
||||
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
|
||||
.build();
|
||||
|
||||
// 4. Add the new user input to the conversation memory.
|
||||
UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build();
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage);
|
||||
// 4. Add the new user message to the conversation memory.
|
||||
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
|
||||
this.getChatMemoryStore().add(conversationId, userMessage);
|
||||
|
||||
return advisedRequest;
|
||||
return processedChatClientRequest;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
|
||||
List<Message> assistantMessages = advisedResponse.response()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages);
|
||||
private void after(ChatClientResponse chatClientResponse) {
|
||||
List<Message> assistantMessages = new ArrayList<>();
|
||||
if (chatClientResponse.chatResponse() != null) {
|
||||
assistantMessages = chatClientResponse.chatResponse()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
}
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages);
|
||||
}
|
||||
|
||||
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory> {
|
||||
|
||||
@@ -16,19 +16,20 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.springframework.util.StringUtils;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
@@ -40,6 +41,7 @@ import org.springframework.ai.chat.model.MessageAggregator;
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Miloš Havránek
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor<ChatMemory> {
|
||||
@@ -83,68 +85,68 @@ public class PromptChatMemoryAdvisor extends AbstractChatMemoryAdvisor<ChatMemor
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
chatClientRequest = this.before(chatClientRequest);
|
||||
|
||||
advisedRequest = this.before(advisedRequest);
|
||||
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
|
||||
this.after(chatClientResponse);
|
||||
|
||||
this.observeAfter(advisedResponse);
|
||||
|
||||
return advisedResponse;
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
Flux<ChatClientResponse> chatClientResponses = this.doNextWithProtectFromBlockingBefore(chatClientRequest,
|
||||
streamAdvisorChain, this::before);
|
||||
|
||||
Flux<AdvisedResponse> advisedResponses = this.doNextWithProtectFromBlockingBefore(advisedRequest, chain,
|
||||
this::before);
|
||||
|
||||
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
|
||||
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::after);
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
private ChatClientRequest before(ChatClientRequest chatClientRequest) {
|
||||
String conversationId = this.doGetConversationId(chatClientRequest.context());
|
||||
int chatMemoryRetrieveSize = this.doGetChatMemoryRetrieveSize(chatClientRequest.context());
|
||||
|
||||
// 1. Advise system parameters.
|
||||
List<Message> memoryMessages = this.getChatMemoryStore()
|
||||
.get(this.doGetConversationId(request.adviseContext()),
|
||||
this.doGetChatMemoryRetrieveSize(request.adviseContext()));
|
||||
// 1. Retrieve the chat memory for the current conversation.
|
||||
List<Message> memoryMessages = this.getChatMemoryStore().get(conversationId, chatMemoryRetrieveSize);
|
||||
|
||||
String memory = (memoryMessages != null) ? memoryMessages.stream()
|
||||
// 2. Processed memory messages as a string.
|
||||
String memory = memoryMessages.stream()
|
||||
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
|
||||
.map(m -> m.getMessageType() + ":" + m.getText())
|
||||
.collect(Collectors.joining(System.lineSeparator())) : "";
|
||||
.collect(Collectors.joining(System.lineSeparator()));
|
||||
|
||||
Map<String, Object> advisedSystemParams = new HashMap<>(request.systemParams());
|
||||
advisedSystemParams.put("memory", memory);
|
||||
// 2. Augment the system message.
|
||||
SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
|
||||
String augmentedSystemText = PromptTemplate.builder()
|
||||
.template(systemMessage.getText() + System.lineSeparator() + this.systemTextAdvise)
|
||||
.variables(Map.of("memory", memory))
|
||||
.build()
|
||||
.render();
|
||||
|
||||
// 2. Advise the system text.
|
||||
String systemText = request.systemText();
|
||||
String advisedSystemText = (StringUtils.hasText(systemText) ? systemText + System.lineSeparator() : "")
|
||||
+ this.systemTextAdvise;
|
||||
|
||||
// 3. Create a new request with the advised system text and parameters.
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.from(request)
|
||||
.systemText(advisedSystemText)
|
||||
.systemParams(advisedSystemParams)
|
||||
// 3. Create a new request with the augmented system message.
|
||||
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
|
||||
.prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
|
||||
.build();
|
||||
|
||||
// 4. Add the new user input to the conversation memory.
|
||||
UserMessage userMessage = UserMessage.builder().text(request.userText()).media(request.media()).build();
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage);
|
||||
// 4. Add the new user message to the conversation memory.
|
||||
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
|
||||
this.getChatMemoryStore().add(conversationId, userMessage);
|
||||
|
||||
return advisedRequest;
|
||||
return processedChatClientRequest;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
|
||||
List<Message> assistantMessages = advisedResponse.response()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages);
|
||||
private void after(ChatClientResponse chatClientResponse) {
|
||||
List<Message> assistantMessages = new ArrayList<>();
|
||||
if (chatClientResponse.chatResponse() != null) {
|
||||
assistantMessages = chatClientResponse.chatResponse()
|
||||
.getResults()
|
||||
.stream()
|
||||
.map(g -> (Message) g.getOutput())
|
||||
.toList();
|
||||
}
|
||||
this.getChatMemoryStore().add(this.doGetConversationId(chatClientResponse.context()), assistantMessages);
|
||||
}
|
||||
|
||||
public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder<ChatMemory> {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -17,15 +17,16 @@
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
@@ -33,14 +34,15 @@ import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* A {@link CallAroundAdvisor} and {@link StreamAroundAdvisor} that filters out the
|
||||
* response if the user input contains any of the sensitive words.
|
||||
* An advisor that blocks the call to the model provider if the user input contains any of
|
||||
* the sensitive words.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public class SafeGuardAdvisor implements CallAdvisor, StreamAdvisor {
|
||||
|
||||
private static final String DEFAULT_FAILURE_RESPONSE = "I'm unable to respond to that due to sensitive content. Could we rephrase or discuss something else?";
|
||||
|
||||
@@ -73,32 +75,33 @@ public class SafeGuardAdvisor implements CallAroundAdvisor, StreamAroundAdvisor
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
if (!CollectionUtils.isEmpty(this.sensitiveWords)
|
||||
&& this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
|
||||
|
||||
return createFailureResponse(advisedRequest);
|
||||
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
|
||||
return createFailureResponse(chatClientRequest);
|
||||
}
|
||||
|
||||
return chain.nextAroundCall(advisedRequest);
|
||||
return callAdvisorChain.nextCall(chatClientRequest);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
if (!CollectionUtils.isEmpty(this.sensitiveWords)
|
||||
&& this.sensitiveWords.stream().anyMatch(w -> advisedRequest.userText().contains(w))) {
|
||||
return Flux.just(createFailureResponse(advisedRequest));
|
||||
&& this.sensitiveWords.stream().anyMatch(w -> chatClientRequest.prompt().getContents().contains(w))) {
|
||||
return Flux.just(createFailureResponse(chatClientRequest));
|
||||
}
|
||||
|
||||
return chain.nextAroundStream(advisedRequest);
|
||||
return streamAdvisorChain.nextStream(chatClientRequest);
|
||||
}
|
||||
|
||||
private AdvisedResponse createFailureResponse(AdvisedRequest advisedRequest) {
|
||||
return new AdvisedResponse(ChatResponse.builder()
|
||||
.generations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
|
||||
.build(), advisedRequest.adviseContext());
|
||||
private ChatClientResponse createFailureResponse(ChatClientRequest chatClientRequest) {
|
||||
return ChatClientResponse.builder()
|
||||
.chatResponse(ChatResponse.builder()
|
||||
.generations(List.of(new Generation(new AssistantMessage(this.failureResponse))))
|
||||
.build())
|
||||
.context(Map.copyOf(chatClientRequest.context()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -18,39 +18,39 @@ package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.MessageAggregator;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.lang.Nullable;
|
||||
|
||||
/**
|
||||
* A simple logger advisor that logs the request and response messages.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public class SimpleLoggerAdvisor implements CallAdvisor, StreamAdvisor {
|
||||
|
||||
public static final Function<AdvisedRequest, String> DEFAULT_REQUEST_TO_STRING = request -> request.toString();
|
||||
public static final Function<ChatClientRequest, String> DEFAULT_REQUEST_TO_STRING = ChatClientRequest::toString;
|
||||
|
||||
public static final Function<ChatResponse, String> DEFAULT_RESPONSE_TO_STRING = response -> ModelOptionsUtils
|
||||
.toJsonStringPrettyPrinter(response);
|
||||
public static final Function<ChatResponse, String> DEFAULT_RESPONSE_TO_STRING = ModelOptionsUtils::toJsonStringPrettyPrinter;
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(SimpleLoggerAdvisor.class);
|
||||
|
||||
private final Function<AdvisedRequest, String> requestToString;
|
||||
private final Function<ChatClientRequest, String> requestToString;
|
||||
|
||||
private final Function<ChatResponse, String> responseToString;
|
||||
|
||||
private int order;
|
||||
private final int order;
|
||||
|
||||
public SimpleLoggerAdvisor() {
|
||||
this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, 0);
|
||||
@@ -60,13 +60,42 @@ public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvis
|
||||
this(DEFAULT_REQUEST_TO_STRING, DEFAULT_RESPONSE_TO_STRING, order);
|
||||
}
|
||||
|
||||
public SimpleLoggerAdvisor(Function<AdvisedRequest, String> requestToString,
|
||||
Function<ChatResponse, String> responseToString, int order) {
|
||||
this.requestToString = requestToString;
|
||||
this.responseToString = responseToString;
|
||||
public SimpleLoggerAdvisor(@Nullable Function<ChatClientRequest, String> requestToString,
|
||||
@Nullable Function<ChatResponse, String> responseToString, int order) {
|
||||
this.requestToString = requestToString != null ? requestToString : DEFAULT_REQUEST_TO_STRING;
|
||||
this.responseToString = responseToString != null ? responseToString : DEFAULT_RESPONSE_TO_STRING;
|
||||
this.order = order;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
logRequest(chatClientRequest);
|
||||
|
||||
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(chatClientRequest);
|
||||
|
||||
logResponse(chatClientResponse);
|
||||
|
||||
return chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
logRequest(chatClientRequest);
|
||||
|
||||
Flux<ChatClientResponse> chatClientResponses = streamAdvisorChain.nextStream(chatClientRequest);
|
||||
|
||||
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse);
|
||||
}
|
||||
|
||||
private void logRequest(ChatClientRequest request) {
|
||||
logger.debug("request: {}", this.requestToString.apply(request));
|
||||
}
|
||||
|
||||
private void logResponse(ChatClientResponse chatClientResponse) {
|
||||
logger.debug("response: {}", this.responseToString.apply(chatClientResponse.chatResponse()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
@@ -77,40 +106,45 @@ public class SimpleLoggerAdvisor implements CallAroundAdvisor, StreamAroundAdvis
|
||||
return this.order;
|
||||
}
|
||||
|
||||
private AdvisedRequest before(AdvisedRequest request) {
|
||||
logger.debug("request: {}", this.requestToString.apply(request));
|
||||
return request;
|
||||
}
|
||||
|
||||
private void observeAfter(AdvisedResponse advisedResponse) {
|
||||
logger.debug("response: {}", this.responseToString.apply(advisedResponse.response()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return SimpleLoggerAdvisor.class.getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
|
||||
advisedRequest = before(advisedRequest);
|
||||
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
|
||||
|
||||
observeAfter(advisedResponse);
|
||||
|
||||
return advisedResponse;
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
public static class Builder {
|
||||
|
||||
advisedRequest = before(advisedRequest);
|
||||
private Function<ChatClientRequest, String> requestToString;
|
||||
|
||||
Flux<AdvisedResponse> advisedResponses = chain.nextAroundStream(advisedRequest);
|
||||
private Function<ChatResponse, String> responseToString;
|
||||
|
||||
private int order = 0;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
public Builder requestToString(Function<ChatClientRequest, String> requestToString) {
|
||||
this.requestToString = requestToString;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder responseToString(Function<ChatResponse, String> responseToString) {
|
||||
this.responseToString = responseToString;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder order(int order) {
|
||||
this.order = order;
|
||||
return this;
|
||||
}
|
||||
|
||||
public SimpleLoggerAdvisor build() {
|
||||
return new SimpleLoggerAdvisor(requestToString, responseToString, order);
|
||||
}
|
||||
|
||||
return new MessageAggregator().aggregateAdvisedResponse(advisedResponses, this::observeAfter);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,457 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* The data of the chat client request that can be modified before the execution of the
|
||||
* ChatClient's call method
|
||||
*
|
||||
* @param chatModel the chat model used
|
||||
* @param userText the text provided by the user
|
||||
* @param systemText the text provided by the system
|
||||
* @param chatOptions the options for the chat
|
||||
* @param media the list of media items
|
||||
* @param toolNames the list of function names
|
||||
* @param toolCallbacks the list of function callbacks
|
||||
* @param messages the list of messages
|
||||
* @param userParams the map of user parameters
|
||||
* @param systemParams the map of system parameters
|
||||
* @param advisors the list of request response advisors
|
||||
* @param advisorParams the map of advisor parameters
|
||||
* @param adviseContext the map of advise context
|
||||
* @param toolContext the tool context
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @deprecated Use {@link ChatClientRequest} instead.
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public record AdvisedRequest(
|
||||
// @formatter:off
|
||||
ChatModel chatModel,
|
||||
String userText,
|
||||
@Nullable
|
||||
String systemText,
|
||||
@Nullable
|
||||
ChatOptions chatOptions,
|
||||
List<Media> media,
|
||||
List<String> toolNames,
|
||||
List<ToolCallback> toolCallbacks,
|
||||
List<Message> messages,
|
||||
Map<String, Object> userParams,
|
||||
Map<String, Object> systemParams,
|
||||
List<Advisor> advisors,
|
||||
@Deprecated // Not really used. Use "adviseContext" instead.
|
||||
Map<String, Object> advisorParams,
|
||||
Map<String, Object> adviseContext,
|
||||
Map<String, Object> toolContext
|
||||
// @formatter:on
|
||||
) {
|
||||
|
||||
public AdvisedRequest {
|
||||
Assert.notNull(chatModel, "chatModel cannot be null");
|
||||
Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages),
|
||||
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
|
||||
Assert.notNull(media, "media cannot be null");
|
||||
Assert.noNullElements(media, "media cannot contain null elements");
|
||||
Assert.notNull(toolNames, "toolNames cannot be null");
|
||||
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
|
||||
Assert.notNull(messages, "messages cannot be null");
|
||||
Assert.noNullElements(messages, "messages cannot contain null elements");
|
||||
Assert.notNull(userParams, "userParams cannot be null");
|
||||
Assert.noNullElements(userParams.keySet(), "userParams keys cannot contain null elements");
|
||||
Assert.noNullElements(userParams.values(), "userParams values cannot contain null elements");
|
||||
Assert.notNull(systemParams, "systemParams cannot be null");
|
||||
Assert.noNullElements(systemParams.keySet(), "systemParams keys cannot contain null elements");
|
||||
Assert.noNullElements(systemParams.values(), "systemParams values cannot contain null elements");
|
||||
Assert.notNull(advisors, "advisors cannot be null");
|
||||
Assert.noNullElements(advisors, "advisors cannot contain null elements");
|
||||
Assert.notNull(advisorParams, "advisorParams cannot be null");
|
||||
Assert.noNullElements(advisorParams.keySet(), "advisorParams keys cannot contain null elements");
|
||||
Assert.noNullElements(advisorParams.values(), "advisorParams values cannot contain null elements");
|
||||
Assert.notNull(adviseContext, "adviseContext cannot be null");
|
||||
Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot contain null elements");
|
||||
Assert.noNullElements(adviseContext.values(), "adviseContext values cannot contain null elements");
|
||||
Assert.notNull(toolContext, "toolContext cannot be null");
|
||||
Assert.noNullElements(toolContext.keySet(), "toolContext keys cannot contain null elements");
|
||||
Assert.noNullElements(toolContext.values(), "toolContext values cannot contain null elements");
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static Builder from(AdvisedRequest from) {
|
||||
Assert.notNull(from, "AdvisedRequest cannot be null");
|
||||
|
||||
Builder builder = new Builder();
|
||||
builder.chatModel = from.chatModel;
|
||||
builder.userText = from.userText;
|
||||
builder.systemText = from.systemText;
|
||||
builder.chatOptions = from.chatOptions;
|
||||
builder.media = from.media;
|
||||
builder.toolNames = from.toolNames;
|
||||
builder.toolCallbacks = from.toolCallbacks;
|
||||
builder.messages = from.messages;
|
||||
builder.userParams = from.userParams;
|
||||
builder.systemParams = from.systemParams;
|
||||
builder.advisors = from.advisors;
|
||||
builder.advisorParams = from.advisorParams;
|
||||
builder.adviseContext = from.adviseContext;
|
||||
builder.toolContext = from.toolContext;
|
||||
return builder;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static AdvisedRequest from(ChatClientRequest from) {
|
||||
Assert.notNull(from, "ChatClientRequest cannot be null");
|
||||
|
||||
List<Message> messages = new LinkedList<>(from.prompt().getInstructions());
|
||||
|
||||
Builder builder = new Builder();
|
||||
if (from.context().get(ChatClientAttributes.CHAT_MODEL.getKey()) instanceof ChatModel chatModel) {
|
||||
builder.chatModel = chatModel;
|
||||
}
|
||||
|
||||
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof UserMessage userMessage) {
|
||||
builder.userText = userMessage.getText();
|
||||
builder.media = userMessage.getMedia();
|
||||
messages.remove(messages.size() - 1);
|
||||
}
|
||||
if (from.context().get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> contextUserParams) {
|
||||
builder.userParams = (Map<String, Object>) contextUserParams;
|
||||
}
|
||||
|
||||
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof SystemMessage systemMessage) {
|
||||
builder.systemText = systemMessage.getText();
|
||||
messages.remove(messages.size() - 1);
|
||||
}
|
||||
if (from.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> contextSystemParams) {
|
||||
builder.systemParams = (Map<String, Object>) contextSystemParams;
|
||||
}
|
||||
|
||||
builder.messages = messages;
|
||||
|
||||
builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build());
|
||||
if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) {
|
||||
builder.toolNames = options.getToolNames().stream().toList();
|
||||
builder.toolCallbacks = options.getToolCallbacks();
|
||||
builder.toolContext = options.getToolContext();
|
||||
}
|
||||
|
||||
if (from.context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List<?> advisors) {
|
||||
builder.advisors = (List<Advisor>) advisors;
|
||||
}
|
||||
builder.advisorParams = Map.of();
|
||||
builder.adviseContext = from.context();
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
|
||||
Assert.notNull(contextTransform, "contextTransform cannot be null");
|
||||
return from(this)
|
||||
.adviseContext(Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext))))
|
||||
.build();
|
||||
}
|
||||
|
||||
public ChatClientRequest toChatClientRequest() {
|
||||
return toChatClientRequest(StTemplateRenderer.builder().build());
|
||||
}
|
||||
|
||||
public ChatClientRequest toChatClientRequest(TemplateRenderer templateRenderer) {
|
||||
return ChatClientRequest.builder()
|
||||
.prompt(toPrompt(templateRenderer))
|
||||
.context(this.adviseContext)
|
||||
.context(ChatClientAttributes.ADVISORS.getKey(), this.advisors)
|
||||
.context(ChatClientAttributes.CHAT_MODEL.getKey(), this.chatModel)
|
||||
.context(ChatClientAttributes.USER_PARAMS.getKey(), this.userParams)
|
||||
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), this.systemParams)
|
||||
.build();
|
||||
}
|
||||
|
||||
public Prompt toPrompt() {
|
||||
return toPrompt(StTemplateRenderer.builder().build());
|
||||
}
|
||||
|
||||
public Prompt toPrompt(TemplateRenderer templateRenderer) {
|
||||
var messages = new ArrayList<>(this.messages());
|
||||
|
||||
String processedSystemText = this.systemText();
|
||||
if (StringUtils.hasText(processedSystemText)) {
|
||||
if (!CollectionUtils.isEmpty(this.systemParams())) {
|
||||
processedSystemText = PromptTemplate.builder()
|
||||
.template(processedSystemText)
|
||||
.variables(this.systemParams())
|
||||
.renderer(templateRenderer)
|
||||
.build()
|
||||
.render();
|
||||
}
|
||||
messages.add(new SystemMessage(processedSystemText));
|
||||
}
|
||||
|
||||
if (StringUtils.hasText(this.userText())) {
|
||||
Map<String, Object> userParams = new HashMap<>(this.userParams());
|
||||
String processedUserText = this.userText();
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
processedUserText = PromptTemplate.builder()
|
||||
.template(processedUserText)
|
||||
.variables(userParams)
|
||||
.renderer(templateRenderer)
|
||||
.build()
|
||||
.render();
|
||||
}
|
||||
messages.add(UserMessage.builder().text(processedUserText).media(this.media()).build());
|
||||
}
|
||||
|
||||
if (this.chatOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
|
||||
if (!this.toolNames().isEmpty()) {
|
||||
toolCallingChatOptions.setToolNames(new HashSet<>(this.toolNames()));
|
||||
}
|
||||
if (!this.toolCallbacks().isEmpty()) {
|
||||
toolCallingChatOptions.setToolCallbacks(this.toolCallbacks());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(this.toolContext())) {
|
||||
toolCallingChatOptions.setToolContext(this.toolContext());
|
||||
}
|
||||
}
|
||||
|
||||
return new Prompt(messages, this.chatOptions());
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder for {@link AdvisedRequest}.
|
||||
*/
|
||||
public static final class Builder {
|
||||
|
||||
private ChatModel chatModel;
|
||||
|
||||
private String userText;
|
||||
|
||||
private String systemText;
|
||||
|
||||
private ChatOptions chatOptions;
|
||||
|
||||
private List<Media> media = List.of();
|
||||
|
||||
private List<String> toolNames = List.of();
|
||||
|
||||
private List<ToolCallback> toolCallbacks = List.of();
|
||||
|
||||
private List<Message> messages = List.of();
|
||||
|
||||
private Map<String, Object> userParams = Map.of();
|
||||
|
||||
private Map<String, Object> systemParams = Map.of();
|
||||
|
||||
private List<Advisor> advisors = List.of();
|
||||
|
||||
private Map<String, Object> advisorParams = Map.of();
|
||||
|
||||
private Map<String, Object> adviseContext = Map.of();
|
||||
|
||||
public Map<String, Object> toolContext = Map.of();
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the chat model.
|
||||
* @param chatModel the chat model
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder chatModel(ChatModel chatModel) {
|
||||
this.chatModel = chatModel;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the user text.
|
||||
* @param userText the user text
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder userText(String userText) {
|
||||
this.userText = userText;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the system text.
|
||||
* @param systemText the system text
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder systemText(String systemText) {
|
||||
this.systemText = systemText;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the chat options.
|
||||
* @param chatOptions the chat options
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder chatOptions(ChatOptions chatOptions) {
|
||||
this.chatOptions = chatOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the media.
|
||||
* @param media the media
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder media(List<Media> media) {
|
||||
this.media = media;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the tool names.
|
||||
* @param toolNames the function names
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder toolNames(List<String> toolNames) {
|
||||
this.toolNames = toolNames;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the tool callbacks.
|
||||
* @param toolCallbacks the tool callbacks
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder functionCallbacks(List<ToolCallback> toolCallbacks) {
|
||||
this.toolCallbacks = toolCallbacks;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the messages.
|
||||
* @param messages the messages
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder messages(List<Message> messages) {
|
||||
this.messages = messages;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the user params.
|
||||
* @param userParams the user params
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder userParams(Map<String, Object> userParams) {
|
||||
this.userParams = userParams;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the system params.
|
||||
* @param systemParams the system params
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder systemParams(Map<String, Object> systemParams) {
|
||||
this.systemParams = systemParams;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisors.
|
||||
* @param advisors the advisors
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder advisors(List<Advisor> advisors) {
|
||||
this.advisors = advisors;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor params.
|
||||
* @param advisorParams the advisor params
|
||||
* @return this {@link Builder} instance
|
||||
* @deprecated in favor of {@link #adviseContext(Map)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorParams(Map<String, Object> advisorParams) {
|
||||
this.advisorParams = advisorParams;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advise context.
|
||||
* @param adviseContext the advise context
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder adviseContext(Map<String, Object> adviseContext) {
|
||||
this.adviseContext = adviseContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the tool context.
|
||||
* @param toolContext the tool context
|
||||
* @return this {@link Builder} instance
|
||||
*/
|
||||
public Builder toolContext(Map<String, Object> toolContext) {
|
||||
this.toolContext = toolContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the {@link AdvisedRequest} instance.
|
||||
* @return a new {@link AdvisedRequest} instance
|
||||
*/
|
||||
public AdvisedRequest build() {
|
||||
return new AdvisedRequest(this.chatModel, this.userText, this.systemText, this.chatOptions, this.media,
|
||||
this.toolNames, this.toolCallbacks, this.messages, this.userParams, this.systemParams,
|
||||
this.advisors, this.advisorParams, this.adviseContext, this.toolContext);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* The data of the chat client response that can be modified before the call returns.
|
||||
*
|
||||
* @param response the chat response
|
||||
* @param adviseContext the context to advise the response
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @deprecated Use {@link ChatClientResponse} instead.
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@Deprecated
|
||||
public record AdvisedResponse(@Nullable ChatResponse response, Map<String, Object> adviseContext) {
|
||||
|
||||
/**
|
||||
* Create a new {@link AdvisedResponse} instance.
|
||||
* @param response the chat response
|
||||
* @param adviseContext the context to advise the response
|
||||
*/
|
||||
public AdvisedResponse {
|
||||
Assert.notNull(adviseContext, "adviseContext cannot be null");
|
||||
Assert.noNullElements(adviseContext.keySet(), "adviseContext keys cannot be null");
|
||||
Assert.noNullElements(adviseContext.values(), "adviseContext values cannot be null");
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new {@link Builder} instance.
|
||||
* @return a new {@link Builder} instance
|
||||
*/
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new {@link Builder} instance from the provided {@link AdvisedResponse}.
|
||||
* @param advisedResponse the advised response to copy
|
||||
* @return a new {@link Builder} instance
|
||||
*/
|
||||
public static Builder from(AdvisedResponse advisedResponse) {
|
||||
Assert.notNull(advisedResponse, "advisedResponse cannot be null");
|
||||
return new Builder().response(advisedResponse.response).adviseContext(advisedResponse.adviseContext);
|
||||
}
|
||||
|
||||
public static AdvisedResponse from(ChatClientResponse chatClientResponse) {
|
||||
Assert.notNull(chatClientResponse, "chatClientResponse cannot be null");
|
||||
return new AdvisedResponse(chatClientResponse.chatResponse(), chatClientResponse.context());
|
||||
}
|
||||
|
||||
public ChatClientResponse toChatClientResponse() {
|
||||
return new ChatClientResponse(this.response, this.adviseContext);
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the context of the advised response.
|
||||
* @param contextTransform the function to transform the context
|
||||
* @return the updated advised response
|
||||
*/
|
||||
public AdvisedResponse updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
|
||||
Assert.notNull(contextTransform, "contextTransform cannot be null");
|
||||
return new AdvisedResponse(this.response,
|
||||
Collections.unmodifiableMap(contextTransform.apply(new HashMap<>(this.adviseContext))));
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder for {@link AdvisedResponse}.
|
||||
*/
|
||||
public static final class Builder {
|
||||
|
||||
@Nullable
|
||||
private ChatResponse response;
|
||||
|
||||
private Map<String, Object> adviseContext;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the chat response.
|
||||
* @param response the chat response
|
||||
* @return the builder
|
||||
*/
|
||||
public Builder response(@Nullable ChatResponse response) {
|
||||
this.response = response;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the context to advise the response.
|
||||
* @param adviseContext the context to advise the response
|
||||
* @return the builder
|
||||
*/
|
||||
public Builder adviseContext(Map<String, Object> adviseContext) {
|
||||
this.adviseContext = adviseContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the {@link AdvisedResponse}.
|
||||
* @return the advised response
|
||||
*/
|
||||
public AdvisedResponse build() {
|
||||
return new AdvisedResponse(this.response, this.adviseContext);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.function.Predicate;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.AdvisorUtils;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* A stream utility class to provide support methods handling {@link AdvisedResponse}.
|
||||
*
|
||||
* @deprecated in favour of {@link AdvisorUtils}.
|
||||
*/
|
||||
@Deprecated
|
||||
public final class AdvisedResponseStreamUtils {
|
||||
|
||||
private AdvisedResponseStreamUtils() {
|
||||
// Avoids instantiation
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a predicate that checks whether the provided {@link AdvisedResponse}
|
||||
* contains a {@link ChatResponse} with at least one result having a non-empty finish
|
||||
* reason in its metadata.
|
||||
* @return a {@link Predicate} that evaluates whether the finish reason exists within
|
||||
* the response metadata.
|
||||
*/
|
||||
public static Predicate<AdvisedResponse> onFinishReason() {
|
||||
return advisedResponse -> {
|
||||
ChatResponse chatResponse = advisedResponse.response();
|
||||
return chatResponse != null && chatResponse.getResults() != null
|
||||
&& chatResponse.getResults()
|
||||
.stream()
|
||||
.anyMatch(result -> result != null && result.getMetadata() != null
|
||||
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
@@ -31,10 +31,10 @@ import org.springframework.util.Assert;
|
||||
* {@link StreamAdvisor}, reducing the boilerplate code needed to implement an advisor.
|
||||
* <p>
|
||||
* It provides default implementations for the
|
||||
* {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)} and
|
||||
* {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)} methods, delegating
|
||||
* the actual logic to the {@link #before(ChatClientRequest, AdvisorChain advisorChain)}
|
||||
* and {@link #after(ChatClientResponse, AdvisorChain advisorChain)} methods.
|
||||
* {@link #adviseCall(ChatClientRequest, CallAdvisorChain)} and
|
||||
* {@link #adviseStream(ChatClientRequest, StreamAdvisorChain)} methods, delegating the
|
||||
* actual logic to the {@link #before(ChatClientRequest, AdvisorChain advisorChain)} and
|
||||
* {@link #after(ChatClientResponse, AdvisorChain advisorChain)} methods.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
@@ -44,82 +44,35 @@ public interface BaseAdvisor extends CallAdvisor, StreamAdvisor {
|
||||
Scheduler DEFAULT_SCHEDULER = Schedulers.boundedElastic();
|
||||
|
||||
@Override
|
||||
default ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
|
||||
default ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
Assert.notNull(chain, "chain cannot be null");
|
||||
Assert.notNull(callAdvisorChain, "callAdvisorChain cannot be null");
|
||||
|
||||
ChatClientRequest processedChatClientRequest = before(chatClientRequest, chain);
|
||||
ChatClientResponse chatClientResponse;
|
||||
if (chain instanceof CallAdvisorChain callAdvisorChain) {
|
||||
chatClientResponse = callAdvisorChain.nextCall(processedChatClientRequest);
|
||||
}
|
||||
else {
|
||||
chatClientResponse = chain.nextAroundCall(AdvisedRequest.from(processedChatClientRequest))
|
||||
.toChatClientResponse();
|
||||
}
|
||||
return after(chatClientResponse, chain);
|
||||
ChatClientRequest processedChatClientRequest = before(chatClientRequest, callAdvisorChain);
|
||||
ChatClientResponse chatClientResponse = callAdvisorChain.nextCall(processedChatClientRequest);
|
||||
return after(chatClientResponse, callAdvisorChain);
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
|
||||
Assert.notNull(chain, "chain cannot be null");
|
||||
|
||||
AdvisedRequest processedAdvisedRequest = before(advisedRequest);
|
||||
AdvisedResponse advisedResponse = chain.nextAroundCall(processedAdvisedRequest);
|
||||
return after(advisedResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
default Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) {
|
||||
default Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
Assert.notNull(chain, "chain cannot be null");
|
||||
Assert.notNull(streamAdvisorChain, "streamAdvisorChain cannot be null");
|
||||
Assert.notNull(getScheduler(), "scheduler cannot be null");
|
||||
|
||||
Flux<ChatClientResponse> chatClientResponseFlux;
|
||||
if (chain instanceof StreamAdvisorChain streamAdvisorChain) {
|
||||
chatClientResponseFlux = Mono.just(chatClientRequest)
|
||||
.publishOn(getScheduler())
|
||||
.map(request -> this.before(request, streamAdvisorChain))
|
||||
.flatMapMany(streamAdvisorChain::nextStream);
|
||||
}
|
||||
else {
|
||||
chatClientResponseFlux = Mono.just(AdvisedRequest.from(chatClientRequest))
|
||||
.publishOn(getScheduler())
|
||||
.map(this::before)
|
||||
.flatMapMany(chain::nextAroundStream)
|
||||
.map(AdvisedResponse::toChatClientResponse);
|
||||
}
|
||||
Flux<ChatClientResponse> chatClientResponseFlux = Mono.just(chatClientRequest)
|
||||
.publishOn(getScheduler())
|
||||
.map(request -> this.before(request, streamAdvisorChain))
|
||||
.flatMapMany(streamAdvisorChain::nextStream);
|
||||
|
||||
return chatClientResponseFlux.map(response -> {
|
||||
if (AdvisorUtils.onFinishReason().test(response)) {
|
||||
response = after(response, chain);
|
||||
response = after(response, streamAdvisorChain);
|
||||
}
|
||||
return response;
|
||||
}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
Assert.notNull(advisedRequest, "advisedRequest cannot be null");
|
||||
Assert.notNull(chain, "chain cannot be null");
|
||||
Assert.notNull(getScheduler(), "scheduler cannot be null");
|
||||
|
||||
Flux<AdvisedResponse> advisedResponses = Mono.just(advisedRequest)
|
||||
.publishOn(getScheduler())
|
||||
.map(this::before)
|
||||
.flatMapMany(chain::nextAroundStream);
|
||||
|
||||
return advisedResponses.map(ar -> {
|
||||
if (AdvisedResponseStreamUtils.onFinishReason().test(ar)) {
|
||||
ar = after(ar);
|
||||
}
|
||||
return ar;
|
||||
}).onErrorResume(error -> Flux.error(new IllegalStateException("Stream processing failed", error)));
|
||||
}
|
||||
|
||||
@Override
|
||||
default String getName() {
|
||||
return this.getClass().getSimpleName();
|
||||
@@ -128,32 +81,12 @@ public interface BaseAdvisor extends CallAdvisor, StreamAdvisor {
|
||||
/**
|
||||
* Logic to be executed before the rest of the advisor chain is called.
|
||||
*/
|
||||
default ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
return before(AdvisedRequest.from(chatClientRequest)).toChatClientRequest();
|
||||
}
|
||||
ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain);
|
||||
|
||||
/**
|
||||
* Logic to be executed after the rest of the advisor chain is called.
|
||||
*/
|
||||
default ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain) {
|
||||
Assert.notNull(chatClientResponse, "chatClientResponse cannot be null");
|
||||
return after(AdvisedResponse.from(chatClientResponse)).toChatClientResponse();
|
||||
}
|
||||
|
||||
/**
|
||||
* Logic to be executed before the rest of the advisor chain is called.
|
||||
* @deprecated in favor of {@link #before(ChatClientRequest,AdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedRequest before(AdvisedRequest request);
|
||||
|
||||
/**
|
||||
* Logic to be executed after the rest of the advisor chain is called.
|
||||
* @deprecated in favor of {@link #after(ChatClientResponse,AdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedResponse after(AdvisedResponse advisedResponse);
|
||||
ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorChain advisorChain);
|
||||
|
||||
/**
|
||||
* Scheduler used for processing the advisor logic when streaming.
|
||||
|
||||
@@ -16,9 +16,6 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
|
||||
/**
|
||||
* A base interface for advisor chains that can be used to chain multiple advisors
|
||||
* together, both for call and stream advisors.
|
||||
@@ -28,8 +25,4 @@ import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
*/
|
||||
public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain {
|
||||
|
||||
default TemplateRenderer getTemplateRenderer() {
|
||||
return StTemplateRenderer.builder().build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -22,20 +22,13 @@ import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
/**
|
||||
* Advisor for execution flows ultimately resulting in a call to an AI model
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface CallAdvisor extends CallAroundAdvisor {
|
||||
public interface CallAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
ChatClientResponse chatClientResponse = adviseCall(advisedRequest.toChatClientRequest(), chain);
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
|
||||
ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain);
|
||||
ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain);
|
||||
|
||||
}
|
||||
|
||||
@@ -25,22 +25,23 @@ import java.util.List;
|
||||
* A chain of {@link CallAdvisor} instances orchestrating the execution of a
|
||||
* {@link ChatClientRequest} on the next {@link CallAdvisor} in the chain.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface CallAdvisorChain extends CallAroundAdvisorChain {
|
||||
public interface CallAdvisorChain extends AdvisorChain {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #nextCall(ChatClientRequest)}
|
||||
* Invokes the next {@link CallAdvisor} in the {@link CallAdvisorChain} with the given
|
||||
* request.
|
||||
*/
|
||||
@Deprecated
|
||||
default AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
|
||||
ChatClientResponse chatClientResponse = nextCall(advisedRequest.toChatClientRequest());
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
|
||||
ChatClientResponse nextCall(ChatClientRequest chatClientRequest);
|
||||
|
||||
List<CallAroundAdvisor> getCallAdvisors();
|
||||
/**
|
||||
* Returns the list of all the {@link CallAdvisor} instances included in this chain at
|
||||
* the time of its creation.
|
||||
*/
|
||||
List<CallAdvisor> getCallAdvisors();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
|
||||
/**
|
||||
* Around advisor that wraps the ChatModel#call(Prompt) method.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link CallAdvisor}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface CallAroundAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
* Around advice that wraps the ChatModel#call(Prompt) method.
|
||||
* @param advisedRequest the advised request
|
||||
* @param chain the advisor chain
|
||||
* @return the response
|
||||
* @deprecated in favor of
|
||||
* {@link CallAdvisor#adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
|
||||
/**
|
||||
* The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain.
|
||||
* Used for non-streaming responses.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link CallAdvisorChain}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface CallAroundAdvisorChain extends AdvisorChain {
|
||||
|
||||
/**
|
||||
* Invokes the next Around Advisor in the CallAroundAdvisorChain with the given
|
||||
* request.
|
||||
* @param advisedRequest the request containing the data to be processed by the next
|
||||
* advisor in the chain.
|
||||
* @return the response generated by the next advisor in the chain.
|
||||
* @deprecated in favor of {@link CallAdvisorChain#nextCall(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest);
|
||||
|
||||
}
|
||||
@@ -23,20 +23,13 @@ import reactor.core.publisher.Flux;
|
||||
/**
|
||||
* Advisor for execution flows ultimately resulting in a streaming call to an AI model.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface StreamAdvisor extends StreamAroundAdvisor {
|
||||
public interface StreamAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
Flux<ChatClientResponse> chatClientResponse = adviseStream(advisedRequest.toChatClientRequest(), chain);
|
||||
return chatClientResponse.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain);
|
||||
Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAdvisorChain streamAdvisorChain);
|
||||
|
||||
}
|
||||
|
||||
@@ -26,22 +26,23 @@ import java.util.List;
|
||||
* A chain of {@link StreamAdvisor} instances orchestrating the execution of a
|
||||
* {@link ChatClientRequest} on the next {@link StreamAdvisor} in the chain.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface StreamAdvisorChain extends StreamAroundAdvisorChain {
|
||||
public interface StreamAdvisorChain extends AdvisorChain {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #nextStream(ChatClientRequest)}
|
||||
* Invokes the next {@link StreamAdvisor} in the {@link StreamAdvisorChain} with the
|
||||
* given request.
|
||||
*/
|
||||
@Deprecated
|
||||
default Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
|
||||
Flux<ChatClientResponse> chatClientResponse = nextStream(advisedRequest.toChatClientRequest());
|
||||
return chatClientResponse.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest);
|
||||
|
||||
List<StreamAroundAdvisor> getStreamAdvisors();
|
||||
/**
|
||||
* Returns the list of all the {@link StreamAdvisor} instances included in this chain
|
||||
* at the time of its creation.
|
||||
*/
|
||||
List<StreamAdvisor> getStreamAdvisors();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* Around advisor that runs around stream based requests.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link StreamAdvisor}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface StreamAroundAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
* Around advice that wraps the invocation of the advised request.
|
||||
* @param advisedRequest the advised request
|
||||
* @param chain the chain of advisors to execute
|
||||
* @return the result of the advised request
|
||||
* @deprecated in favor of
|
||||
* {@link StreamAdvisor#adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* The StreamAroundAdvisorChain is used to delegate the call to the next
|
||||
* StreamAroundAdvisor in the chain. Used for streaming responses.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link StreamAdvisorChain}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface StreamAroundAdvisorChain extends AdvisorChain {
|
||||
|
||||
/**
|
||||
* This method delegates the call to the next StreamAroundAdvisor in the chain and is
|
||||
* used for streaming responses.
|
||||
* @param advisedRequest the request containing data of the chat client that can be
|
||||
* modified before execution
|
||||
* @return a Flux stream of AdvisedResponse objects
|
||||
* @deprecated in favor of {@link StreamAdvisorChain#nextStream(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest);
|
||||
|
||||
}
|
||||
@@ -16,18 +16,12 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.observation;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import io.micrometer.observation.Observation;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Context used to store metadata for chat client advisors.
|
||||
@@ -47,41 +41,6 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
@Nullable
|
||||
private ChatClientResponse chatClientResponse;
|
||||
|
||||
/**
|
||||
* the shared data between the advisors in the chain. It is shared between all request
|
||||
* and response advising points of all advisors in the chain.
|
||||
*/
|
||||
@Nullable
|
||||
private Map<String, Object> advisorResponseContext;
|
||||
|
||||
/**
|
||||
* Create a new {@link AdvisorObservationContext}.
|
||||
* @param advisorName the advisor name
|
||||
* @param advisorType the advisor type
|
||||
* @param advisorRequest the advised request
|
||||
* @param advisorRequestContext the shared data between the advisors in the chain
|
||||
* @param advisorResponseContext the shared data between the advisors in the chain
|
||||
* @param order the order of the advisor in the advisor chain
|
||||
* @deprecated use the builder instead
|
||||
*/
|
||||
@Deprecated
|
||||
public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest,
|
||||
@Nullable Map<String, Object> advisorRequestContext, @Nullable Map<String, Object> advisorResponseContext,
|
||||
int order) {
|
||||
Assert.hasText(advisorName, "advisorName cannot be null or empty");
|
||||
|
||||
this.advisorName = advisorName;
|
||||
this.chatClientRequest = advisorRequest != null ? advisorRequest.toChatClientRequest()
|
||||
: ChatClientRequest.builder().prompt(new Prompt()).build();
|
||||
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
|
||||
this.chatClientRequest.context().putAll(advisorRequestContext);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(advisorResponseContext)) {
|
||||
this.chatClientResponse = ChatClientResponse.builder().context(advisorResponseContext).build();
|
||||
}
|
||||
this.order = order;
|
||||
}
|
||||
|
||||
AdvisorObservationContext(String advisorName, ChatClientRequest chatClientRequest, int order) {
|
||||
Assert.hasText(advisorName, "advisorName cannot be null or empty");
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
@@ -120,111 +79,6 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
this.chatClientResponse = chatClientResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* The type of the advisor.
|
||||
* @return the type of the advisor
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public Type getAdvisorType() {
|
||||
return Type.AROUND;
|
||||
}
|
||||
|
||||
/**
|
||||
* The order of the advisor in the advisor chain.
|
||||
* @return the order of the advisor in the advisor chain
|
||||
* @deprecated not used anymore
|
||||
*/
|
||||
@Deprecated
|
||||
public AdvisedRequest getAdvisedRequest() {
|
||||
return AdvisedRequest.from(this.chatClientRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link AdvisedRequest} data to be advised. Represents the row
|
||||
* {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}.
|
||||
* @param advisedRequest the advised request
|
||||
* @deprecated immutable object, use the builder instead to create a new instance
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisedRequest(@Nullable AdvisedRequest advisedRequest) {
|
||||
throw new IllegalStateException(
|
||||
"The AdvisedRequest is immutable. Build a new AdvisorObservationContext instead.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @return the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #getChatClientRequest()} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, Object> getAdvisorRequestContext() {
|
||||
return this.chatClientRequest.context();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @param advisorRequestContext the shared data between the advisors in the chain
|
||||
* @deprecated not supported anymore, use {@link #getChatClientRequest()} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisorRequestContext(@Nullable Map<String, Object> advisorRequestContext) {
|
||||
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
|
||||
this.chatClientRequest.context().putAll(advisorRequestContext);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @return the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #getChatClientResponse()} instead
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public Map<String, Object> getAdvisorResponseContext() {
|
||||
if (this.chatClientResponse != null) {
|
||||
return this.chatClientResponse.context();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @param advisorResponseContext the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisorResponseContext(@Nullable Map<String, Object> advisorResponseContext) {
|
||||
this.advisorResponseContext = advisorResponseContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* The type of the advisor.
|
||||
*
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public enum Type {
|
||||
|
||||
/**
|
||||
* The advisor is called before the advised request is executed.
|
||||
*/
|
||||
BEFORE,
|
||||
/**
|
||||
* The advisor is called after the advised request is executed.
|
||||
*/
|
||||
AFTER,
|
||||
/**
|
||||
* The advisor is called around the advised request.
|
||||
*/
|
||||
AROUND
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Builder for {@link AdvisorObservationContext}.
|
||||
*/
|
||||
@@ -236,12 +90,6 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
|
||||
private int order = 0;
|
||||
|
||||
private AdvisedRequest advisorRequest;
|
||||
|
||||
private Map<String, Object> advisorRequestContext;
|
||||
|
||||
private Map<String, Object> advisorResponseContext;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
@@ -260,65 +108,8 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor type.
|
||||
* @param advisorType the advisor type
|
||||
* @return the builder
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorType(Type advisorType) {
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advised request.
|
||||
* @param advisedRequest the advised request
|
||||
* @return the builder
|
||||
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisedRequest(AdvisedRequest advisedRequest) {
|
||||
this.advisorRequest = advisedRequest;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor request context.
|
||||
* @param advisorRequestContext the advisor request context
|
||||
* @return the builder
|
||||
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorRequestContext(Map<String, Object> advisorRequestContext) {
|
||||
this.advisorRequestContext = advisorRequestContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor response context.
|
||||
* @param advisorResponseContext the advisor response context
|
||||
* @return the builder
|
||||
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorResponseContext(Map<String, Object> advisorResponseContext) {
|
||||
this.advisorResponseContext = advisorResponseContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
public AdvisorObservationContext build() {
|
||||
if (chatClientRequest != null && advisorRequest != null) {
|
||||
throw new IllegalArgumentException(
|
||||
"ChatClientRequest and AdvisedRequest cannot be set at the same time");
|
||||
}
|
||||
else if (chatClientRequest != null) {
|
||||
return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order);
|
||||
}
|
||||
else {
|
||||
return new AdvisorObservationContext(this.advisorName, Type.AROUND, this.advisorRequest,
|
||||
this.advisorRequestContext, this.advisorResponseContext, this.order);
|
||||
}
|
||||
return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -87,18 +87,6 @@ public enum AdvisorObservationDocumentation implements ObservationDocumentation
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Advisor type: Before, After or Around.
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
ADVISOR_TYPE {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.advisor.type";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Advisor name.
|
||||
*/
|
||||
|
||||
@@ -69,8 +69,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
@Override
|
||||
public KeyValues getLowCardinalityKeyValues(AdvisorObservationContext context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorType(context),
|
||||
advisorName(context));
|
||||
return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorName(context));
|
||||
}
|
||||
|
||||
protected KeyValue aiOperationType(AdvisorObservationContext context) {
|
||||
@@ -81,14 +80,6 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
return KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER, AiProvider.SPRING_AI.value());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
protected KeyValue advisorType(AdvisorObservationContext context) {
|
||||
return KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE, context.getAdvisorType().name());
|
||||
}
|
||||
|
||||
protected KeyValue springAiKind() {
|
||||
return KeyValue.of(LowCardinalityKeyNames.SPRING_AI_KIND, SpringAiKind.ADVISOR.value());
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationFilter;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.observation.tracing.TracingHelper;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An {@link ObservationFilter} to include the chat prompt content in the observation.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link ChatClientPromptContentObservationFilter}.
|
||||
*/
|
||||
@Deprecated
|
||||
public class ChatClientInputContentObservationFilter implements ObservationFilter {
|
||||
|
||||
@Override
|
||||
public Observation.Context map(Observation.Context context) {
|
||||
if (!(context instanceof ChatClientObservationContext chatClientObservationContext)) {
|
||||
return context;
|
||||
}
|
||||
chatClientSystemText(chatClientObservationContext);
|
||||
chatClientSystemParams(chatClientObservationContext);
|
||||
chatClientUserText(chatClientObservationContext);
|
||||
chatClientUserParams(chatClientObservationContext);
|
||||
|
||||
return chatClientObservationContext;
|
||||
}
|
||||
|
||||
protected void chatClientSystemText(ChatClientObservationContext context) {
|
||||
List<Message> messages = context.getRequest().prompt().getInstructions();
|
||||
if (CollectionUtils.isEmpty(messages)) {
|
||||
return;
|
||||
}
|
||||
|
||||
var systemMessage = messages.stream()
|
||||
.filter(message -> message instanceof SystemMessage)
|
||||
.reduce((first, second) -> second);
|
||||
if (systemMessage.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT,
|
||||
systemMessage.get().getText()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected void chatClientSystemParams(ChatClientObservationContext context) {
|
||||
if (!(context.getRequest()
|
||||
.context()
|
||||
.get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> systemParams)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(systemParams)) {
|
||||
return;
|
||||
}
|
||||
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM,
|
||||
TracingHelper.concatenateMaps((Map<String, Object>) systemParams)));
|
||||
}
|
||||
|
||||
protected void chatClientUserText(ChatClientObservationContext context) {
|
||||
List<Message> messages = context.getRequest().prompt().getInstructions();
|
||||
if (CollectionUtils.isEmpty(messages)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!(messages.get(messages.size() - 1) instanceof UserMessage userMessage)) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT,
|
||||
userMessage.getText()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected void chatClientUserParams(ChatClientObservationContext context) {
|
||||
if (!(context.getRequest()
|
||||
.context()
|
||||
.get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> userParams)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(userParams)) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS,
|
||||
TracingHelper.concatenateMaps((Map<String, Object>) userParams)));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -78,12 +78,7 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return this.stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in the
|
||||
* ChatModelObservationContext.
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public String getFormat() {
|
||||
if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) {
|
||||
return format;
|
||||
@@ -91,21 +86,13 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in the
|
||||
* ChatModelObservationContext.
|
||||
*/
|
||||
@Deprecated
|
||||
public void setFormat(@Nullable String format) {
|
||||
this.request.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
|
||||
}
|
||||
|
||||
public static final class Builder {
|
||||
|
||||
private ChatClientRequest chatClientRequest;
|
||||
|
||||
private List<? extends Advisor> advisors = List.of();
|
||||
|
||||
@Nullable
|
||||
private String format;
|
||||
|
||||
private boolean isStream = false;
|
||||
@@ -118,17 +105,7 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated // use request(ChatClientRequest chatClientRequest)
|
||||
public Builder withRequest(ChatClientRequest chatClientRequest) {
|
||||
return request(chatClientRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in
|
||||
* the ChatModelObservationContext.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withFormat(String format) {
|
||||
public Builder format(@Nullable String format) {
|
||||
this.format = format;
|
||||
return this;
|
||||
}
|
||||
@@ -143,11 +120,6 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated // use stream(boolean isStream)
|
||||
public Builder withStream(boolean isStream) {
|
||||
return stream(isStream);
|
||||
}
|
||||
|
||||
public ChatClientObservationContext build() {
|
||||
if (StringUtils.hasText(format)) {
|
||||
this.chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
|
||||
|
||||
@@ -109,94 +109,6 @@ public enum ChatClientObservationDocumentation implements ObservationDocumentati
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Enabled tool function names.
|
||||
* @deprecated replaced by {@link #CHAT_CLIENT_TOOL_NAMES}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_TOOL_FUNCTION_NAMES {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.tool.function.names";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* List of configured chat client function callbacks.
|
||||
* @deprecated replaced by {@link #CHAT_CLIENT_TOOL_NAMES}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.tool.function.callbacks";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Map of advisor parameters.
|
||||
* @deprecated risk to expose sensitive information or break the instrumentation
|
||||
* since the advisor context map is used to pass arbitrary Java objects between
|
||||
* advisors and not necessarily serializable. The conversation ID, previously part
|
||||
* of this, is already included in the {@link #CHAT_CLIENT_CONVERSATION_ID}
|
||||
* method.
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_ADVISOR_PARAMS {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.advisor.params";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Chat client user text.
|
||||
* @deprecated replaced by {@link #PROMPT}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_USER_TEXT {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.user.text";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Chat client user parameters.
|
||||
* @deprecated replaced by {@link #PROMPT}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_USER_PARAMS {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.user.params";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Chat client system text.
|
||||
* @deprecated replaced by {@link #PROMPT}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_SYSTEM_TEXT {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.system.text";
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Chat client system parameters.
|
||||
* @deprecated replaced by {@link #PROMPT}
|
||||
*/
|
||||
@Deprecated
|
||||
CHAT_CLIENT_SYSTEM_PARAM {
|
||||
@Override
|
||||
public String asString() {
|
||||
return "spring.ai.chat.client.system.params";
|
||||
}
|
||||
},
|
||||
|
||||
// Content
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,14 +16,11 @@
|
||||
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.common.KeyValues;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
|
||||
@@ -97,12 +94,6 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva
|
||||
keyValues = advisors(keyValues, context);
|
||||
keyValues = conversationId(keyValues, context);
|
||||
keyValues = tools(keyValues, context);
|
||||
// @deprecated remove before 1.0.0-RC1.
|
||||
keyValues = chatClientAdvisorParams(keyValues, context);
|
||||
// @deprecated remove before 1.0.0-RC1.
|
||||
keyValues = toolNames(keyValues, context);
|
||||
// @deprecated remove before 1.0.0-RC1.
|
||||
keyValues = toolCallbacks(keyValues, context);
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
@@ -123,7 +114,8 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva
|
||||
var conversationIdValue = context.getRequest()
|
||||
.context()
|
||||
.get(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY);
|
||||
if (!(conversationIdValue instanceof String conversationId) || StringUtils.isEmpty(conversationId)) {
|
||||
|
||||
if (!(conversationIdValue instanceof String conversationId) || !StringUtils.hasText(conversationId)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
@@ -154,71 +146,4 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva
|
||||
TracingHelper.concatenateStrings(toolNames.stream().sorted().toList()));
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated risk to expose sensitive information or break the instrumentation since
|
||||
* the advisor context map is used to pass arbitrary Java objects between advisors and
|
||||
* not necessarily serializable. The conversation ID, previously part of this, is
|
||||
* already included in the
|
||||
* {@link #conversationId(KeyValues, ChatClientObservationContext)} method.
|
||||
*/
|
||||
@Deprecated
|
||||
protected KeyValues chatClientAdvisorParams(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().context())) {
|
||||
return keyValues;
|
||||
}
|
||||
var chatClientContext = new HashMap<>(context.getRequest().context());
|
||||
Arrays.stream(ChatClientAttributes.values()).forEach(attribute -> chatClientContext.remove(attribute.getKey()));
|
||||
return keyValues.and(
|
||||
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(),
|
||||
TracingHelper.concatenateMaps(chatClientContext));
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favor of {@link #tools(KeyValues, ChatClientObservationContext)}
|
||||
*/
|
||||
@Deprecated
|
||||
protected KeyValues toolNames(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (context.getRequest().prompt().getOptions() == null) {
|
||||
return keyValues;
|
||||
}
|
||||
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolNames = options.getToolNames();
|
||||
if (CollectionUtils.isEmpty(toolNames)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
return keyValues.and(
|
||||
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(),
|
||||
TracingHelper.concatenateStrings(toolNames.stream().sorted().toList()));
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favor of {@link #tools(KeyValues, ChatClientObservationContext)}
|
||||
*/
|
||||
@Deprecated
|
||||
protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (context.getRequest().prompt().getOptions() == null) {
|
||||
return keyValues;
|
||||
}
|
||||
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolCallbacks = options.getToolCallbacks();
|
||||
if (CollectionUtils.isEmpty(toolCallbacks)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolCallbackNames = toolCallbacks.stream()
|
||||
.map(toolCallback -> toolCallback.getToolDefinition().name())
|
||||
.sorted()
|
||||
.toList();
|
||||
return keyValues
|
||||
.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS
|
||||
.asString(), TracingHelper.concatenateStrings(toolCallbackNames));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -24,9 +24,9 @@ import java.util.function.Consumer;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
@@ -42,31 +42,28 @@ import org.springframework.util.StringUtils;
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Alexandros Pappas
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class MessageAggregator {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class);
|
||||
|
||||
public Flux<AdvisedResponse> aggregateAdvisedResponse(Flux<AdvisedResponse> advisedResponses,
|
||||
Consumer<AdvisedResponse> aggregationHandler) {
|
||||
public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientResponse> chatClientResponses,
|
||||
Consumer<ChatClientResponse> aggregationHandler) {
|
||||
|
||||
AtomicReference<Map<String, Object>> adviseContext = new AtomicReference<>(new HashMap<>());
|
||||
|
||||
return new MessageAggregator().aggregate(advisedResponses.map(ar -> {
|
||||
adviseContext.get().putAll(ar.adviseContext());
|
||||
return ar.response();
|
||||
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());
|
||||
|
||||
return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
|
||||
context.get().putAll(chatClientResponse.context());
|
||||
return chatClientResponse.chatResponse();
|
||||
}), aggregatedChatResponse -> {
|
||||
|
||||
AdvisedResponse aggregatedAdvisedResponse = AdvisedResponse.builder()
|
||||
.response(aggregatedChatResponse)
|
||||
.adviseContext(adviseContext.get())
|
||||
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
|
||||
.chatResponse(aggregatedChatResponse)
|
||||
.context(context.get())
|
||||
.build();
|
||||
|
||||
aggregationHandler.accept(aggregatedAdvisedResponse);
|
||||
|
||||
}).map(cr -> new AdvisedResponse(cr, adviseContext.get()));
|
||||
aggregationHandler.accept(aggregatedChatClientResponse);
|
||||
}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build());
|
||||
}
|
||||
|
||||
public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
|
||||
|
||||
@@ -713,7 +713,7 @@ public class ChatClientTest {
|
||||
assertThat(content).isEqualTo("response");
|
||||
|
||||
assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(2);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(systemMessage.getText()).isEqualTo("instructions");
|
||||
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
|
||||
}
|
||||
@@ -747,7 +747,7 @@ public class ChatClientTest {
|
||||
assertThat(content).isEqualTo("response");
|
||||
|
||||
assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(2);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(systemMessage.getText()).isEqualTo("other instructions");
|
||||
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
|
||||
}
|
||||
@@ -769,7 +769,7 @@ public class ChatClientTest {
|
||||
assertThat(content).isEqualTo("response");
|
||||
|
||||
assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(2);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(systemMessage.getText()).isEqualTo("instructions");
|
||||
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
|
||||
}
|
||||
@@ -808,7 +808,7 @@ public class ChatClientTest {
|
||||
assertThat(content).isEqualTo("response");
|
||||
|
||||
assertThat(this.promptCaptor.getValue().getInstructions()).hasSize(4);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(2);
|
||||
var systemMessage = this.promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(systemMessage.getText()).isEqualTo("other instructions");
|
||||
assertThat(systemMessage.getMessageType()).isEqualTo(MessageType.SYSTEM);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,448 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.definition.ToolDefinition;
|
||||
import org.springframework.ai.tool.metadata.ToolMetadata;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link DefaultChatClientUtils}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class DefaultChatClientUtilsTests {
|
||||
|
||||
@Test
|
||||
void whenInputRequestIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> DefaultChatClientUtils.toChatClientRequest(null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("inputRequest cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSystemTextIsProvidedThenSystemMessageIsAddedToPrompt() {
|
||||
String systemText = "System instructions";
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.system(systemText);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(systemText);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSystemTextWithParamsIsProvidedThenSystemMessageIsRenderedAndAddedToPrompt() {
|
||||
String systemText = "System instructions for {name}";
|
||||
Map<String, Object> systemParams = Map.of("name", "Spring AI");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.system(s -> s.text(systemText).params(systemParams));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System instructions for Spring AI");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenMessagesAreProvidedThenTheyAreAddedToPrompt() {
|
||||
List<Message> messages = List.of(new SystemMessage("System message"), new UserMessage("User message"));
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.messages(messages);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).hasSize(2);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System message");
|
||||
assertThat(result.prompt().getInstructions().get(1).getText()).isEqualTo("User message");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextIsProvidedThenUserMessageIsAddedToPrompt() {
|
||||
String userText = "User question";
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.user(userText);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(userText);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextWithParamsIsProvidedThenUserMessageIsRenderedAndAddedToPrompt() {
|
||||
String userText = "Question about {topic}";
|
||||
Map<String, Object> userParams = Map.of("topic", "Spring AI");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.user(s -> s.text(userText).params(userParams));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("Question about Spring AI");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextWithMediaIsProvidedThenUserMessageWithMediaIsAddedToPrompt() {
|
||||
String userText = "What's in this image?";
|
||||
Media media = mock(Media.class);
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.user(s -> s.text(userText).media(media));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(UserMessage.class);
|
||||
UserMessage userMessage = (UserMessage) result.prompt().getInstructions().get(0);
|
||||
assertThat(userMessage.getText()).isEqualTo(userText);
|
||||
assertThat(userMessage.getMedia()).contains(media);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSystemTextAndSystemMessageAreProvidedThenSystemTextIsFirst() {
|
||||
String systemText = "System instructions";
|
||||
List<Message> messages = List.of(new SystemMessage("System message"));
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.system(systemText)
|
||||
.messages(messages);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).hasSize(2);
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo(systemText);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextAndUserMessageAreProvidedThenUserTextIsLast() {
|
||||
String userText = "User question";
|
||||
List<Message> messages = List.of(new UserMessage("User message"));
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.user(userText)
|
||||
.messages(messages);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).hasSize(2);
|
||||
assertThat(result.prompt().getInstructions()).last().isInstanceOf(UserMessage.class);
|
||||
assertThat(result.prompt().getInstructions()).last().extracting(Message::getText).isEqualTo(userText);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolCallingChatOptionsIsProvidedThenToolNamesAreSet() {
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
List<String> toolNames = List.of("tool1", "tool2");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolNames(toolNames.toArray(new String[0]));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolCallingChatOptionsIsProvidedThenToolCallbacksAreSet() {
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
ToolCallback toolCallback = new TestToolCallback("tool1");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolCallbacks(toolCallback);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolCallbacks()).contains(toolCallback);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolCallingChatOptionsIsProvidedThenToolContextIsSet() {
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
Map<String, Object> toolContext = Map.of("key", "value");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolContext(toolContext);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolNamesAndChatOptionsAreProvidedThenTheToolNamesOverride() {
|
||||
Set<String> toolNames1 = Set.of("toolA", "toolB");
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolNames(toolNames1).build();
|
||||
List<String> toolNames2 = List.of("tool1", "tool2");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolNames(toolNames2.toArray(new String[0]));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolCallbacksAndChatOptionsAreProvidedThenTheToolCallbacksOverride() {
|
||||
ToolCallback toolCallback1 = new TestToolCallback("tool1");
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallback1).build();
|
||||
ToolCallback toolCallback2 = new TestToolCallback("tool2");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolCallbacks(toolCallback2);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolCallbacks()).containsExactlyInAnyOrder(toolCallback2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolContextAndChatOptionsAreProvidedThenTheValuesAreMerged() {
|
||||
Map<String, Object> toolContext1 = Map.of("key1", "value1");
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolContext(toolContext1).build();
|
||||
Map<String, Object> toolContext2 = Map.of("key2", "value2");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.options(chatOptions)
|
||||
.toolContext(toolContext2);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext1)
|
||||
.containsAllEntriesOf(toolContext2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdvisorParamsAreProvidedThenTheyAreAddedToContext() {
|
||||
Map<String, Object> advisorParams = Map.of("key1", "value1", "key2", "value2");
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.advisors(a -> a.params(advisorParams));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.context()).containsAllEntriesOf(advisorParams);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenCustomTemplateRendererIsProvidedThenItIsUsedForRendering() {
|
||||
String systemText = "Instructions <name>";
|
||||
Map<String, Object> systemParams = Map.of("name", "Spring AI");
|
||||
TemplateRenderer customRenderer = StTemplateRenderer.builder()
|
||||
.startDelimiterToken('<')
|
||||
.endDelimiterToken('>')
|
||||
.build();
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.system(s -> s.text(systemText).params(systemParams))
|
||||
.templateRenderer(customRenderer);
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
assertThat(result.prompt().getInstructions()).isNotEmpty();
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("Instructions Spring AI");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAllComponentsAreProvidedThenCompleteRequestIsCreated() {
|
||||
String systemText = "System instructions for {name}";
|
||||
Map<String, Object> systemParams = Map.of("name", "Spring AI");
|
||||
|
||||
String userText = "Question about {topic}";
|
||||
Map<String, Object> userParams = Map.of("topic", "Spring AI");
|
||||
Media media = mock(Media.class);
|
||||
|
||||
List<Message> messages = List.of(new UserMessage("Intermediate message"));
|
||||
|
||||
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
List<String> toolNames = List.of("tool1", "tool2");
|
||||
ToolCallback toolCallback = new TestToolCallback("tool3");
|
||||
Map<String, Object> toolContext = Map.of("toolKey", "toolValue");
|
||||
|
||||
Map<String, Object> advisorParams = Map.of("advisorKey", "advisorValue");
|
||||
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
DefaultChatClient.DefaultChatClientRequestSpec inputRequest = (DefaultChatClient.DefaultChatClientRequestSpec) ChatClient
|
||||
.create(chatModel)
|
||||
.prompt()
|
||||
.system(s -> s.text(systemText).params(systemParams))
|
||||
.user(u -> u.text(userText).params(userParams).media(media))
|
||||
.messages(messages)
|
||||
.toolNames(toolNames.toArray(new String[0]))
|
||||
.toolCallbacks(toolCallback)
|
||||
.toolContext(toolContext)
|
||||
.options(chatOptions)
|
||||
.advisors(a -> a.params(advisorParams));
|
||||
|
||||
ChatClientRequest result = DefaultChatClientUtils.toChatClientRequest(inputRequest);
|
||||
|
||||
assertThat(result).isNotNull();
|
||||
|
||||
assertThat(result.prompt().getInstructions()).hasSize(3);
|
||||
assertThat(result.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(0).getText()).isEqualTo("System instructions for Spring AI");
|
||||
assertThat(result.prompt().getInstructions().get(1).getText()).isEqualTo("Intermediate message");
|
||||
assertThat(result.prompt().getInstructions().get(2)).isInstanceOf(UserMessage.class);
|
||||
assertThat(result.prompt().getInstructions().get(2).getText()).isEqualTo("Question about Spring AI");
|
||||
UserMessage userMessage = (UserMessage) result.prompt().getInstructions().get(2);
|
||||
assertThat(userMessage.getMedia()).contains(media);
|
||||
|
||||
assertThat(result.prompt().getOptions()).isInstanceOf(ToolCallingChatOptions.class);
|
||||
ToolCallingChatOptions resultOptions = (ToolCallingChatOptions) result.prompt().getOptions();
|
||||
assertThat(resultOptions).isNotNull();
|
||||
assertThat(resultOptions.getToolNames()).containsExactlyInAnyOrderElementsOf(toolNames);
|
||||
assertThat(resultOptions.getToolCallbacks()).contains(toolCallback);
|
||||
assertThat(resultOptions.getToolContext()).containsAllEntriesOf(toolContext);
|
||||
|
||||
assertThat(result.context()).containsAllEntriesOf(advisorParams);
|
||||
}
|
||||
|
||||
static class TestToolCallback implements ToolCallback {
|
||||
|
||||
private final ToolDefinition toolDefinition;
|
||||
|
||||
private final ToolMetadata toolMetadata;
|
||||
|
||||
TestToolCallback(String name) {
|
||||
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
|
||||
this.toolMetadata = ToolMetadata.builder().build();
|
||||
}
|
||||
|
||||
TestToolCallback(String name, boolean returnDirect) {
|
||||
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
|
||||
this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ToolDefinition getToolDefinition() {
|
||||
return this.toolDefinition;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ToolMetadata getToolMetadata() {
|
||||
return this.toolMetadata;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String call(String toolInput) {
|
||||
return "Mission accomplished!";
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2025-2025 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -14,83 +14,84 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.List;
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link AdvisedResponseStreamUtils}.
|
||||
* Unit tests for {@link AdvisorUtils}.
|
||||
*
|
||||
* @author ghdcksgml1
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class AdvisedResponseStreamUtilsTest {
|
||||
class AdvisorUtilsTests {
|
||||
|
||||
@Nested
|
||||
class OnFinishReason {
|
||||
|
||||
@Test
|
||||
void whenChatResponseIsNullThenReturnFalse() {
|
||||
AdvisedResponse response = mock(AdvisedResponse.class);
|
||||
given(response.response()).willReturn(null);
|
||||
ChatClientResponse chatClientResponse = mock(ChatClientResponse.class);
|
||||
given(chatClientResponse.chatResponse()).willReturn(null);
|
||||
|
||||
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
|
||||
boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse);
|
||||
|
||||
assertFalse(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenChatResponseResultsIsNullThenReturnFalse() {
|
||||
AdvisedResponse response = mock(AdvisedResponse.class);
|
||||
ChatClientResponse chatClientResponse = mock(ChatClientResponse.class);
|
||||
ChatResponse chatResponse = mock(ChatResponse.class);
|
||||
|
||||
given(chatResponse.getResults()).willReturn(null);
|
||||
given(response.response()).willReturn(chatResponse);
|
||||
given(chatClientResponse.chatResponse()).willReturn(chatResponse);
|
||||
|
||||
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
|
||||
boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse);
|
||||
|
||||
assertFalse(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenChatIsRunningThenReturnFalse() {
|
||||
AdvisedResponse response = mock(AdvisedResponse.class);
|
||||
ChatClientResponse chatClientResponse = mock(ChatClientResponse.class);
|
||||
ChatResponse chatResponse = mock(ChatResponse.class);
|
||||
|
||||
Generation generation = new Generation(new AssistantMessage("running.."), ChatGenerationMetadata.NULL);
|
||||
|
||||
given(chatResponse.getResults()).willReturn(List.of(generation));
|
||||
given(response.response()).willReturn(chatResponse);
|
||||
given(chatClientResponse.chatResponse()).willReturn(chatResponse);
|
||||
|
||||
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
|
||||
boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse);
|
||||
|
||||
assertFalse(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenChatIsStopThenReturnTrue() {
|
||||
AdvisedResponse response = mock(AdvisedResponse.class);
|
||||
ChatClientResponse chatClientResponse = mock(ChatClientResponse.class);
|
||||
ChatResponse chatResponse = mock(ChatResponse.class);
|
||||
|
||||
Generation generation = new Generation(new AssistantMessage("finish."),
|
||||
ChatGenerationMetadata.builder().finishReason("STOP").build());
|
||||
|
||||
given(chatResponse.getResults()).willReturn(List.of(generation));
|
||||
given(response.response()).willReturn(chatResponse);
|
||||
given(chatClientResponse.chatResponse()).willReturn(chatResponse);
|
||||
|
||||
boolean result = AdvisedResponseStreamUtils.onFinishReason().test(response);
|
||||
boolean result = AdvisorUtils.onFinishReason().test(chatClientResponse);
|
||||
|
||||
assertTrue(result);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -30,12 +30,12 @@ import org.mockito.junit.jupiter.MockitoExtension;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
@@ -84,8 +84,8 @@ public class AdvisorsTests {
|
||||
assertThat(content).isEqualTo("Hello John");
|
||||
|
||||
// AROUND
|
||||
assertThat(mockAroundAdvisor1.advisedResponse.response()).isNotNull();
|
||||
assertThat(mockAroundAdvisor1.advisedResponse.adviseContext()).containsEntry("key1", "value1")
|
||||
assertThat(mockAroundAdvisor1.chatClientResponse.chatResponse()).isNotNull();
|
||||
assertThat(mockAroundAdvisor1.chatClientResponse.context()).containsEntry("key1", "value1")
|
||||
.containsEntry("key2", "value2")
|
||||
.containsEntry("aroundCallBeforeAdvisor1", "AROUND_CALL_BEFORE Advisor1")
|
||||
.containsEntry("aroundCallAfterAdvisor1", "AROUND_CALL_AFTER Advisor1")
|
||||
@@ -126,10 +126,10 @@ public class AdvisorsTests {
|
||||
assertThat(content).isEqualTo("Hello John");
|
||||
|
||||
// AROUND
|
||||
assertThat(mockAroundAdvisor1.aroundAdvisedResponses).isNotEmpty();
|
||||
assertThat(mockAroundAdvisor1.advisedChatClientResponses).isNotEmpty();
|
||||
|
||||
mockAroundAdvisor1.aroundAdvisedResponses.stream()
|
||||
.forEach(advisedResponse -> assertThat(advisedResponse.adviseContext()).containsEntry("key1", "value1")
|
||||
mockAroundAdvisor1.advisedChatClientResponses.stream()
|
||||
.forEach(chatClientResponse -> assertThat(chatClientResponse.context()).containsEntry("key1", "value1")
|
||||
.containsEntry("key2", "value2")
|
||||
.containsEntry("aroundStreamBeforeAdvisor1", "AROUND_STREAM_BEFORE Advisor1")
|
||||
.containsEntry("aroundStreamAfterAdvisor1", "AROUND_STREAM_AFTER Advisor1")
|
||||
@@ -142,17 +142,17 @@ public class AdvisorsTests {
|
||||
verify(this.chatModel).stream(this.promptCaptor.capture());
|
||||
}
|
||||
|
||||
public class MockAroundAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
|
||||
public class MockAroundAdvisor implements CallAdvisor, StreamAdvisor {
|
||||
|
||||
private final String name;
|
||||
|
||||
private final int order;
|
||||
|
||||
public AdvisedRequest advisedRequest;
|
||||
public ChatClientRequest chatClientRequest;
|
||||
|
||||
public AdvisedResponse advisedResponse;
|
||||
public ChatClientResponse chatClientResponse;
|
||||
|
||||
public List<AdvisedResponse> aroundAdvisedResponses = new ArrayList<>();
|
||||
public List<ChatClientResponse> advisedChatClientResponses = new ArrayList<>();
|
||||
|
||||
public MockAroundAdvisor(String name, int order) {
|
||||
this.name = name;
|
||||
@@ -170,45 +170,38 @@ public class AdvisorsTests {
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAdvisorChain callAdvisorChain) {
|
||||
this.chatClientRequest = chatClientRequest.mutate()
|
||||
.context(Map.of("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName(), "lastBefore",
|
||||
getName()))
|
||||
.build();
|
||||
|
||||
this.advisedRequest = advisedRequest.updateContext(context -> {
|
||||
context.put("aroundCallBefore" + getName(), "AROUND_CALL_BEFORE " + getName());
|
||||
context.put("lastBefore", getName());
|
||||
return context;
|
||||
});
|
||||
var chatClientResponse = callAdvisorChain.nextCall(this.chatClientRequest);
|
||||
|
||||
this.advisedResponse = chain.nextAroundCall(this.advisedRequest);
|
||||
AdvisedResponse advisedResponse = this.advisedResponse;
|
||||
this.chatClientResponse = chatClientResponse.mutate()
|
||||
.context(
|
||||
Map.of("aroundCallAfter" + getName(), "AROUND_CALL_AFTER " + getName(), "lastAfter", getName()))
|
||||
.build();
|
||||
|
||||
this.advisedResponse = advisedResponse.updateContext(context -> {
|
||||
context.put("aroundCallAfter" + this.name, "AROUND_CALL_AFTER " + this.name);
|
||||
context.put("lastAfter", this.name);
|
||||
return context;
|
||||
});
|
||||
|
||||
return this.advisedResponse;
|
||||
return this.chatClientResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
|
||||
StreamAdvisorChain streamAdvisorChain) {
|
||||
this.chatClientRequest = chatClientRequest.mutate()
|
||||
.context(Map.of("aroundStreamBefore" + getName(), "AROUND_STREAM_BEFORE " + getName(), "lastBefore",
|
||||
getName()))
|
||||
.build();
|
||||
|
||||
this.advisedRequest = advisedRequest.updateContext(context -> {
|
||||
context.put("aroundStreamBefore" + this.name, "AROUND_STREAM_BEFORE " + this.name);
|
||||
context.put("lastBefore", this.name);
|
||||
return context;
|
||||
});
|
||||
|
||||
Flux<AdvisedResponse> advisedResponseStream = chain.nextAroundStream(this.advisedRequest);
|
||||
|
||||
return advisedResponseStream.map(advisedResponse -> {
|
||||
return advisedResponse.updateContext(context -> {
|
||||
context.put("aroundStreamAfter" + this.name, "AROUND_STREAM_AFTER " + this.name);
|
||||
context.put("lastAfter", this.name);
|
||||
return context;
|
||||
});
|
||||
}).doOnNext(ar -> this.aroundAdvisedResponses.add(ar));
|
||||
Flux<ChatClientResponse> chatClientResponseFlux = streamAdvisorChain.nextStream(this.chatClientRequest);
|
||||
|
||||
return chatClientResponseFlux
|
||||
.map(chatClientResponse -> chatClientResponse.mutate()
|
||||
.context(Map.of("aroundStreamAfter" + getName(), "AROUND_STREAM_AFTER " + getName(), "lastAfter",
|
||||
getName()))
|
||||
.build())
|
||||
.doOnNext(ar -> this.advisedChatClientResponses.add(ar));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ public class SimpleLoggerAdvisorTests {
|
||||
UserMessage userMessage = (UserMessage) this.promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(userMessage.getText()).isEqualToIgnoringWhitespace("Please answer my question XYZ");
|
||||
|
||||
assertThat(output.getOut()).contains("request: AdvisedRequest", "userText=Please answer my question XYZ");
|
||||
assertThat(output.getOut()).contains("request: ChatClientRequest", "Please answer my question XYZ");
|
||||
assertThat(output.getOut()).contains("response:", "finishReason");
|
||||
}
|
||||
|
||||
|
||||
@@ -1,279 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.template.TemplateRenderer;
|
||||
import org.springframework.ai.template.st.StTemplateRenderer;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link AdvisedRequest}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class AdvisedRequestTests {
|
||||
|
||||
@Test
|
||||
void buildAdvisedRequest() {
|
||||
AdvisedRequest request = new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of());
|
||||
assertThat(request).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenChatModelIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(null, "user", null, null, List.of(), List.of(), List.of(),
|
||||
List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("chatModel cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage(
|
||||
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserTextIsEmptyThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage(
|
||||
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenMediaIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, null, List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("media cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenFunctionNamesIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), null,
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("toolNames cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolCallbacksIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
null, List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("toolCallbacks cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenMessagesIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), null, Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("messages cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUserParamsIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), null, Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("userParams cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSystemParamsIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), null, List.of(), Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("systemParams cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdvisorsIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), null, Map.of(), Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisors cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdvisorParamsIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), null, Map.of(), Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisorParams cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdviseContextIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), null, Map.of()))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("adviseContext cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenToolContextIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "user", null, null, List.of(), List.of(),
|
||||
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("toolContext cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenConvertToAndFromChatClientRequestWithDefaultTemplateRenderer() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
ChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
List<Message> messages = List.of(mock(UserMessage.class));
|
||||
SystemMessage systemMessage = new SystemMessage("Instructions {key}");
|
||||
UserMessage userMessage = UserMessage.builder().text("Question {key}").media(mock(Media.class)).build();
|
||||
Map<String, Object> systemParams = Map.of("key", "value");
|
||||
Map<String, Object> userParams = Map.of("key", "value");
|
||||
List<String> toolNames = List.of("tool1", "tool2");
|
||||
ToolCallback toolCallback = mock(ToolCallback.class);
|
||||
Map<String, Object> toolContext = Map.of("key", "value");
|
||||
List<Advisor> advisors = List.of(mock(Advisor.class));
|
||||
Map<String, Object> advisorContext = Map.of("key", "value");
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.builder()
|
||||
.chatModel(chatModel)
|
||||
.chatOptions(chatOptions)
|
||||
.messages(messages)
|
||||
.systemText(systemMessage.getText())
|
||||
.systemParams(systemParams)
|
||||
.userText(userMessage.getText())
|
||||
.userParams(userParams)
|
||||
.media(userMessage.getMedia())
|
||||
.toolNames(toolNames)
|
||||
.functionCallbacks(List.of(toolCallback))
|
||||
.toolContext(toolContext)
|
||||
.advisors(advisors)
|
||||
.adviseContext(advisorContext)
|
||||
.build();
|
||||
|
||||
ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest();
|
||||
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.CHAT_MODEL.getKey())).isEqualTo(chatModel);
|
||||
assertThat(chatClientRequest.prompt().getOptions()).isEqualTo(chatOptions);
|
||||
assertThat(chatClientRequest.prompt().getInstructions()).hasSize(3);
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(0)).isEqualTo(messages.get(0));
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(1).getText()).isEqualTo("Instructions value");
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(2).getText()).isEqualTo("Question value");
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolNames())
|
||||
.containsAll(toolNames);
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolCallbacks())
|
||||
.contains(toolCallback);
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolContext())
|
||||
.containsAllEntriesOf(toolContext);
|
||||
assertThat((List<Advisor>) chatClientRequest.context().get(ChatClientAttributes.ADVISORS.getKey()))
|
||||
.containsAll(advisors);
|
||||
assertThat(chatClientRequest.context()).containsAllEntriesOf(advisorContext);
|
||||
|
||||
AdvisedRequest convertedAdvisedRequest = AdvisedRequest.from(chatClientRequest);
|
||||
assertThat(convertedAdvisedRequest.toPrompt()).isEqualTo(chatClientRequest.prompt());
|
||||
assertThat(convertedAdvisedRequest.adviseContext()).containsAllEntriesOf(chatClientRequest.context());
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.USER_PARAMS.getKey())).isEqualTo(userParams);
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()))
|
||||
.isEqualTo(systemParams);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenConvertToAndFromChatClientRequestWithCustomTemplateRenderer() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
ChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
SystemMessage systemMessage = new SystemMessage("Instructions <name>");
|
||||
UserMessage userMessage = UserMessage.builder().text("Question <name>").media(mock(Media.class)).build();
|
||||
Map<String, Object> systemParams = Map.of("name", "Spring AI");
|
||||
Map<String, Object> userParams = Map.of("name", "Spring AI");
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.builder()
|
||||
.chatModel(chatModel)
|
||||
.chatOptions(chatOptions)
|
||||
.systemText(systemMessage.getText())
|
||||
.systemParams(systemParams)
|
||||
.userText(userMessage.getText())
|
||||
.userParams(userParams)
|
||||
.media(userMessage.getMedia())
|
||||
.build();
|
||||
|
||||
TemplateRenderer customRenderer = StTemplateRenderer.builder()
|
||||
.startDelimiterToken('<')
|
||||
.endDelimiterToken('>')
|
||||
.build();
|
||||
ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest(customRenderer);
|
||||
|
||||
assertThat(chatClientRequest.prompt().getInstructions()).hasSize(2);
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(1)).isInstanceOf(UserMessage.class);
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.USER_PARAMS.getKey())).isEqualTo(userParams);
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()))
|
||||
.isEqualTo(systemParams);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUsingToPromptWithCustomTemplateRenderer() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
SystemMessage systemMessage = new SystemMessage("Instructions <name>");
|
||||
UserMessage userMessage = UserMessage.builder().text("Question <name>").media(mock(Media.class)).build();
|
||||
Map<String, Object> systemParams = Map.of("name", "Spring AI");
|
||||
Map<String, Object> userParams = Map.of("name", "Spring AI");
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.builder()
|
||||
.chatModel(chatModel)
|
||||
.systemText(systemMessage.getText())
|
||||
.systemParams(systemParams)
|
||||
.userText(userMessage.getText())
|
||||
.userParams(userParams)
|
||||
.media(userMessage.getMedia())
|
||||
.build();
|
||||
|
||||
TemplateRenderer customRenderer = StTemplateRenderer.builder()
|
||||
.startDelimiterToken('<')
|
||||
.endDelimiterToken('>')
|
||||
.build();
|
||||
var prompt = advisedRequest.toPrompt(customRenderer);
|
||||
|
||||
assertThat(prompt.getInstructions()).hasSize(2);
|
||||
assertThat(prompt.getInstructions().get(0).getText()).isEqualTo("Instructions Spring AI");
|
||||
assertThat(prompt.getInstructions().get(1).getText()).isEqualTo("Question Spring AI");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link AdvisedResponse}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class AdvisedResponseTests {
|
||||
|
||||
@Test
|
||||
void buildAdvisedResponse() {
|
||||
AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of());
|
||||
assertThat(advisedResponse).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdviseContextIsNullThenThrows() {
|
||||
assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("adviseContext cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdviseContextKeysIsNullThenThrows() {
|
||||
Map<String, Object> adviseContext = new HashMap<>();
|
||||
adviseContext.put(null, "value");
|
||||
assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("adviseContext keys cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenAdviseContextValuesIsNullThenThrows() {
|
||||
Map<String, Object> adviseContext = new HashMap<>();
|
||||
adviseContext.put("key", null);
|
||||
assertThatThrownBy(() -> new AdvisedResponse(mock(ChatResponse.class), adviseContext))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("adviseContext values cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenBuildFromNullAdvisedResponseThenThrows() {
|
||||
assertThatThrownBy(() -> AdvisedResponse.from((AdvisedResponse) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisedResponse cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildFromAdvisedResponse() {
|
||||
AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of());
|
||||
AdvisedResponse.Builder builder = AdvisedResponse.from(advisedResponse);
|
||||
assertThat(builder).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenUpdateFromNullContextThenThrows() {
|
||||
AdvisedResponse advisedResponse = new AdvisedResponse(mock(ChatResponse.class), Map.of());
|
||||
assertThatThrownBy(() -> advisedResponse.updateContext(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("contextTransform cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenConvertToAndFromChatClientResponse() {
|
||||
ChatResponse chatResponse = mock(ChatResponse.class);
|
||||
Map<String, Object> context = Map.of("key", "value");
|
||||
AdvisedResponse advisedResponse = new AdvisedResponse(chatResponse, context);
|
||||
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
|
||||
AdvisedResponse newAdvisedResponse = AdvisedResponse.from(chatClientResponse);
|
||||
assertThat(newAdvisedResponse).isEqualTo(advisedResponse);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,12 +18,10 @@ package org.springframework.ai.chat.client.advisor.observation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link AdvisorObservationContext}.
|
||||
@@ -36,6 +34,7 @@ class AdvisorObservationContextTests {
|
||||
@Test
|
||||
void whenMandatoryOptionsThenReturn() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.advisorName("AdvisorName")
|
||||
.build();
|
||||
|
||||
@@ -44,19 +43,17 @@ class AdvisorObservationContextTests {
|
||||
|
||||
@Test
|
||||
void missingAdvisorName() {
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder().build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("advisorName cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenBuilderWithAdvisedRequestThenReturn() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("AdvisorName")
|
||||
.advisedRequest(mock(AdvisedRequest.class))
|
||||
.build();
|
||||
|
||||
assertThat(observationContext).isNotNull();
|
||||
void missingChatClientRequest() {
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder().advisorName("AdvisorName").build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("chatClientRequest cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -69,13 +66,4 @@ class AdvisorObservationContextTests {
|
||||
assertThat(observationContext).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void missingBuilderWithBothRequestsThenThrow() {
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder()
|
||||
.advisedRequest(mock(AdvisedRequest.class))
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build())
|
||||
.build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("ChatClientRequest and AdvisedRequest cannot be set at the same time");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,8 +20,10 @@ import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.HighCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation.LowCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.observation.conventions.SpringAiKind;
|
||||
@@ -46,6 +48,7 @@ class DefaultAdvisorObservationConventionTests {
|
||||
@Test
|
||||
void contextualName() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.advisorName("MyName")
|
||||
.build();
|
||||
assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("my_name");
|
||||
@@ -54,6 +57,7 @@ class DefaultAdvisorObservationConventionTests {
|
||||
@Test
|
||||
void supportsAdvisorObservationContext() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.advisorName("MyName")
|
||||
.build();
|
||||
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
|
||||
@@ -63,6 +67,7 @@ class DefaultAdvisorObservationConventionTests {
|
||||
@Test
|
||||
void shouldHaveLowCardinalityKeyValuesWhenDefined() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.advisorName("MyName")
|
||||
.build();
|
||||
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
|
||||
@@ -75,6 +80,7 @@ class DefaultAdvisorObservationConventionTests {
|
||||
@Test
|
||||
void shouldHaveKeyValuesWhenDefinedAndResponse() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt("Hello")).build())
|
||||
.advisorName("MyName")
|
||||
.order(678)
|
||||
.build();
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link ChatClientInputContentObservationFilter}.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class ChatClientInputContentObservationFilterTests {
|
||||
|
||||
private final ChatClientInputContentObservationFilter observationFilter = new ChatClientInputContentObservationFilter();
|
||||
|
||||
@Mock
|
||||
ChatModel chatModel;
|
||||
|
||||
@Test
|
||||
void whenNotSupportedObservationContextThenReturnOriginalContext() {
|
||||
var expectedContext = new Observation.Context();
|
||||
var actualContext = this.observationFilter.map(expectedContext);
|
||||
|
||||
assertThat(actualContext).isEqualTo(expectedContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenEmptyInputContentThenReturnOriginalContext() {
|
||||
var request = ChatClientRequest.builder().prompt(new Prompt()).build();
|
||||
|
||||
var expectedContext = ChatClientObservationContext.builder().request(request).build();
|
||||
|
||||
var actualContext = this.observationFilter.map(expectedContext);
|
||||
|
||||
assertThat(actualContext).isEqualTo(expectedContext);
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenWithTextThenAugmentContext() {
|
||||
var request = ChatClientRequest.builder()
|
||||
.prompt(new Prompt(new SystemMessage("sample system text"), new UserMessage("sample user text")))
|
||||
.context(ChatClientAttributes.USER_PARAMS.getKey(), Map.of("up1", "upv1"))
|
||||
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), Map.of("sp1", "sp1v"))
|
||||
.build();
|
||||
|
||||
var originalContext = ChatClientObservationContext.builder().request(request).build();
|
||||
|
||||
var augmentedContext = this.observationFilter.map(originalContext);
|
||||
|
||||
assertThat(augmentedContext.getHighCardinalityKeyValues())
|
||||
.contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT.asString(), "sample user text"));
|
||||
assertThat(augmentedContext.getHighCardinalityKeyValues())
|
||||
.contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS.asString(), "[\"up1\":\"upv1\"]"));
|
||||
assertThat(augmentedContext.getHighCardinalityKeyValues())
|
||||
.contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT.asString(), "sample system text"));
|
||||
assertThat(augmentedContext.getHighCardinalityKeyValues())
|
||||
.contains(KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM.asString(), "[\"sp1\":\"sp1v\"]"));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -19,7 +19,6 @@ package org.springframework.ai.chat.client.observation;
|
||||
import java.util.List;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.common.KeyValues;
|
||||
import io.micrometer.observation.Observation;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -27,13 +26,11 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
@@ -62,8 +59,8 @@ class DefaultChatClientObservationConventionTests {
|
||||
|
||||
ChatClientRequest request;
|
||||
|
||||
static CallAroundAdvisor dummyAdvisor(String name) {
|
||||
return new CallAroundAdvisor() {
|
||||
static CallAdvisor dummyAdvisor(String name) {
|
||||
return new CallAdvisor() {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
@@ -76,7 +73,8 @@ class DefaultChatClientObservationConventionTests {
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest,
|
||||
CallAdvisorChain callAdvisorChain) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -156,7 +154,7 @@ class DefaultChatClientObservationConventionTests {
|
||||
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.request(request)
|
||||
.withFormat("json")
|
||||
.format("json")
|
||||
.advisors(List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")))
|
||||
.stream(true)
|
||||
.build();
|
||||
@@ -166,33 +164,7 @@ class DefaultChatClientObservationConventionTests {
|
||||
["advisor1", "advisor2"]"""),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_CONVERSATION_ID.asString(), "007"),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_NAMES.asString(), """
|
||||
["tool1", "tool2", "toolCallback1", "toolCallback2"]"""),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(), """
|
||||
["chat_memory_conversation_id":"007"]"""),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(), """
|
||||
["tool1", "tool2"]"""),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(), """
|
||||
["toolCallback1", "toolCallback2"]"""));
|
||||
}
|
||||
|
||||
@Test
|
||||
void entriesInAdvisorContextAreNotRemoved() {
|
||||
var request = ChatClientRequest.builder()
|
||||
.prompt(new Prompt(""))
|
||||
.context("advParam1", "advisorParam1Value")
|
||||
.context(ChatClientAttributes.ADVISORS.getKey(),
|
||||
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")))
|
||||
.build();
|
||||
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.request(request)
|
||||
.build();
|
||||
|
||||
assertThat(observationContext.getRequest().context()).hasSize(2);
|
||||
|
||||
this.observationConvention.chatClientAdvisorParams(KeyValues.empty(), observationContext);
|
||||
|
||||
assertThat(observationContext.getRequest().context()).hasSize(2);
|
||||
["tool1", "tool2", "toolCallback1", "toolCallback2"]"""));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
public class ChatTests {
|
||||
|
||||
// @Test
|
||||
// void testChat() {
|
||||
//
|
||||
// String customerStyle = "American English in a calm and respectful tone";
|
||||
// String customerEmail = "Arrr, I be fuming that me blender lid "
|
||||
// + "flew off and splattered me kitchen walls "
|
||||
// + "with smoothie! And to make matters worse, "
|
||||
// + "the warranty don't cover the cost of "
|
||||
// + "cleaning up me kitchen. I need yer help "
|
||||
// + "right now, matey!";
|
||||
// ChatOpenAi chatOpenAi = new ChatOpenAi();
|
||||
// chatOpenAi
|
||||
//
|
||||
// }
|
||||
|
||||
}
|
||||
@@ -29,6 +29,13 @@ For details, refer to:
|
||||
[[upgrading-to-1-0-0-RC1]]
|
||||
== Upgrading to 1.0.0-RC1
|
||||
|
||||
=== Chat Client And Advisors
|
||||
|
||||
* When building a `Prompt` from the ChatClient input, the `SystemMessage` built from `systemText()` is now placed first in the message list. Before, it was put last, resulting in errors with several model providers.
|
||||
* In `AbstractChatMemoryAdvisor`, the `doNextWithProtectFromBlockingBefore()` protected method has been changed from accepting the old `AdvisedRequest` to the new `ChatClientRequest`. It’s a breaking change since the alternative was not part of M8.
|
||||
* `MessageAggregator` has a new method to aggregate messages from `ChatClientRequest`. The previous method aggregating messages from the old `AdvisedRequest` has been removed, since it was already marked as deprecated in M8.
|
||||
* In `SimpleLoggerAdvisor`, the `requestToString` input argument needs to be updated to use `ChatClientRequest`. It’s a breaking change since the alternative was not part of M8 yet. Same thing about the constructor.
|
||||
|
||||
=== Breaking Changes
|
||||
The Watson AI model was removed as it was based on the older text generation that is considered outdated as there is a new chat generation model available.
|
||||
Hopefully Watson will reappear in a future version of Spring AI
|
||||
|
||||
@@ -100,6 +100,20 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
return this.messages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the first system message in the prompt. If no system message is found, an empty
|
||||
* SystemMessage is returned.
|
||||
*/
|
||||
public SystemMessage getSystemMessage() {
|
||||
for (int i = 0; i <= this.messages.size() - 1; i++) {
|
||||
Message message = this.messages.get(i);
|
||||
if (message instanceof SystemMessage systemMessage) {
|
||||
return systemMessage;
|
||||
}
|
||||
}
|
||||
return new SystemMessage("");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the last user message in the prompt. If no user message is found, an empty
|
||||
* UserMessage is returned.
|
||||
@@ -165,11 +179,44 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param userMessageAugmenter the function to augment the last user message.
|
||||
* @return a new prompt instance with the augmented user message.
|
||||
* Augments the first system message in the prompt with the provided function. If no
|
||||
* system message is found, a new one is created with the provided text.
|
||||
* @return a new {@link Prompt} instance with the augmented system message.
|
||||
*/
|
||||
public Prompt augmentSystemMessage(Function<SystemMessage, SystemMessage> systemMessageAugmenter) {
|
||||
|
||||
var messagesCopy = new ArrayList<>(this.messages);
|
||||
for (int i = 0; i <= this.messages.size() - 1; i++) {
|
||||
Message message = messagesCopy.get(i);
|
||||
if (message instanceof SystemMessage systemMessage) {
|
||||
messagesCopy.set(i, systemMessageAugmenter.apply(systemMessage));
|
||||
break;
|
||||
}
|
||||
if (i == 0) {
|
||||
// If no system message is found, create a new one with the provided text
|
||||
// and add it as the first item in the list.
|
||||
messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage("")));
|
||||
}
|
||||
}
|
||||
|
||||
return new Prompt(messagesCopy, null == this.chatOptions ? null : this.chatOptions.copy());
|
||||
}
|
||||
|
||||
/**
|
||||
* Augments the last system message in the prompt with the provided text. If no system
|
||||
* message is found, a new one is created with the provided text.
|
||||
* @return a new {@link Prompt} instance with the augmented system message.
|
||||
*/
|
||||
public Prompt augmentSystemMessage(String newSystemText) {
|
||||
return augmentSystemMessage(systemMessage -> systemMessage.mutate().text(newSystemText).build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Augments the last user message in the prompt with the provided function. If no user
|
||||
* message is found, a new one is created with the provided text.
|
||||
* @return a new {@link Prompt} instance with the augmented user message.
|
||||
*/
|
||||
public Prompt augmentUserMessage(Function<UserMessage, UserMessage> userMessageAugmenter) {
|
||||
|
||||
var messagesCopy = new ArrayList<>(this.messages);
|
||||
for (int i = messagesCopy.size() - 1; i >= 0; i--) {
|
||||
Message message = messagesCopy.get(i);
|
||||
@@ -186,11 +233,9 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a copy of the prompt, replacing the text content of the last UserMessage
|
||||
* with the provided text. If no UserMessage exists, a new one with the given text is
|
||||
* added.
|
||||
* @param newUserText The new text content for the last user message.
|
||||
* @return A new Prompt instance with the augmented user message text.
|
||||
* Augments the last user message in the prompt with the provided text. If no user
|
||||
* message is found, a new one is created with the provided text.
|
||||
* @return a new {@link Prompt} instance with the augmented user message.
|
||||
*/
|
||||
public Prompt augmentUserMessage(String newUserText) {
|
||||
return augmentUserMessage(userMessage -> userMessage.mutate().text(newUserText).build());
|
||||
|
||||
@@ -152,10 +152,90 @@ class PromptTests {
|
||||
|
||||
Prompt copy = prompt.augmentUserMessage(message -> message.mutate().text("How are you?").build());
|
||||
|
||||
assertThat(copy.getInstructions().get(copy.getInstructions().size() - 1)).isInstanceOf(UserMessage.class);
|
||||
assertThat(copy.getUserMessage()).isNotNull();
|
||||
assertThat(copy.getUserMessage().getText()).isEqualTo("How are you?");
|
||||
assertThat(prompt.getUserMessage()).isNotNull();
|
||||
assertThat(prompt.getUserMessage().getText()).isEqualTo("");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getSystemMessageWhenSingle() {
|
||||
Prompt prompt = Prompt.builder().messages(new SystemMessage("Hello")).build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getSystemMessageWhenMultiple() {
|
||||
Prompt prompt = Prompt.builder()
|
||||
.messages(new SystemMessage("Hello"), new SystemMessage("How are you?"))
|
||||
.build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
}
|
||||
|
||||
@Test
|
||||
void getSystemMessageWhenNone() {
|
||||
Prompt prompt = Prompt.builder().messages(new UserMessage("You'll be back!")).build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("");
|
||||
|
||||
prompt = Prompt.builder().messages(List.of()).build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("");
|
||||
}
|
||||
|
||||
@Test
|
||||
void augmentSystemMessageWhenSingle() {
|
||||
Prompt prompt = Prompt.builder().messages(new SystemMessage("Hello")).build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
|
||||
Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build());
|
||||
|
||||
assertThat(copy.getSystemMessage()).isNotNull();
|
||||
assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?");
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
}
|
||||
|
||||
@Test
|
||||
void augmentSystemMessageWhenMultiple() {
|
||||
Prompt prompt = Prompt.builder()
|
||||
.messages(new SystemMessage("Hello"), new SystemMessage("How are you?"))
|
||||
.build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
|
||||
Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("What about you?").build());
|
||||
|
||||
assertThat(copy.getSystemMessage()).isNotNull();
|
||||
assertThat(copy.getSystemMessage().getText()).isEqualTo("What about you?");
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello");
|
||||
}
|
||||
|
||||
@Test
|
||||
void augmentSystemMessageWhenNone() {
|
||||
Prompt prompt = Prompt.builder().messages(new UserMessage("You'll be back!")).build();
|
||||
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("");
|
||||
|
||||
Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build());
|
||||
|
||||
assertThat(copy.getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(copy.getSystemMessage()).isNotNull();
|
||||
assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?");
|
||||
assertThat(prompt.getSystemMessage()).isNotNull();
|
||||
assertThat(prompt.getSystemMessage().getText()).isEqualTo("");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -25,8 +25,6 @@ import java.util.stream.Collectors;
|
||||
|
||||
import reactor.core.scheduler.Scheduler;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
@@ -100,16 +98,6 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favour of {@link #before(ChatClientRequest, AdvisorChain)}
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public AdvisedRequest before(AdvisedRequest advisedRequest) {
|
||||
ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest();
|
||||
return AdvisedRequest.from(before(chatClientRequest, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable AdvisorChain advisorChain) {
|
||||
Map<String, Object> context = new HashMap<>(chatClientRequest.context());
|
||||
@@ -163,16 +151,6 @@ public final class RetrievalAugmentationAdvisor implements BaseAdvisor {
|
||||
return Map.entry(query, documents);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated in favour of {@link #after(ChatClientResponse, AdvisorChain)}
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public AdvisedResponse after(AdvisedResponse advisedResponse) {
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
return AdvisedResponse.from(after(chatClientResponse, null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse after(ChatClientResponse chatClientResponse, @Nullable AdvisorChain advisorChain) {
|
||||
ChatResponse.Builder chatResponseBuilder;
|
||||
|
||||
@@ -18,6 +18,7 @@ package org.springframework.ai.vectorstore.pgvector;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
@@ -26,6 +27,7 @@ import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.ArgumentMatchers;
|
||||
import org.mockito.Mockito;
|
||||
import org.postgresql.ds.PGSimpleDataSource;
|
||||
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
|
||||
import org.testcontainers.containers.PostgreSQLContainer;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -75,22 +77,20 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
|
||||
return chatModel;
|
||||
}
|
||||
|
||||
private static void initStore(PgVectorStore store) throws Exception {
|
||||
private static void initStore(PgVectorStore store, String conversationId) {
|
||||
store.afterPropertiesSet();
|
||||
// fill the store
|
||||
store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")),
|
||||
new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER"))));
|
||||
store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", conversationId)),
|
||||
new Document("Tell me a bad joke", Map.of("conversationId", conversationId, "messageType", "USER"))));
|
||||
}
|
||||
|
||||
private static PgVectorStore createPgVectorStoreUsingTestcontainer(EmbeddingModel embeddingModel) throws Exception {
|
||||
JdbcTemplate jdbcTemplate = createJdbcTemplateWithConnectionToTestcontainer();
|
||||
PgVectorStore vectorStore = PgVectorStore.builder(jdbcTemplate, embeddingModel)
|
||||
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
|
||||
.dimensions(3) // match
|
||||
// embeddings
|
||||
.initializeSchema(true)
|
||||
.build();
|
||||
initStore(vectorStore);
|
||||
return vectorStore;
|
||||
}
|
||||
|
||||
private static @NotNull JdbcTemplate createJdbcTemplateWithConnectionToTestcontainer() {
|
||||
@@ -105,7 +105,7 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
|
||||
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
verify(chatModel).call(promptCaptor.capture());
|
||||
assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualTo("""
|
||||
assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
|
||||
|
||||
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
|
||||
|
||||
@@ -129,19 +129,59 @@ class PgVectorStoreWithChatMemoryAdvisorIT {
|
||||
// faked embedding model
|
||||
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
|
||||
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
|
||||
String conversationId = UUID.randomUUID().toString();
|
||||
initStore(store, conversationId);
|
||||
|
||||
// do the chat
|
||||
ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.user("joke")
|
||||
.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
|
||||
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
|
||||
.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
|
||||
.call()
|
||||
.chatResponse();
|
||||
|
||||
verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel);
|
||||
}
|
||||
|
||||
@Test
|
||||
void advisedChatShouldHaveSimilarMessagesFromVectorStoreWhenSystemMessageProvided() throws Exception {
|
||||
// faked ChatModel
|
||||
ChatModel chatModel = chatModelAlwaysReturnsTheSameReply();
|
||||
// faked embedding model
|
||||
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
|
||||
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
|
||||
String conversationId = UUID.randomUUID().toString();
|
||||
initStore(store, conversationId);
|
||||
|
||||
// do the chat
|
||||
ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.system("You are a helpful assistant.")
|
||||
.user("joke")
|
||||
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
|
||||
.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
|
||||
.call()
|
||||
.chatResponse();
|
||||
|
||||
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
verify(chatModel).call(promptCaptor.capture());
|
||||
assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class);
|
||||
assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
|
||||
You are a helpful assistant.
|
||||
|
||||
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
|
||||
|
||||
---------------------
|
||||
LONG_TERM_MEMORY:
|
||||
Tell me a good joke
|
||||
Tell me a bad joke
|
||||
---------------------
|
||||
""");
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
|
||||
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
|
||||
|
||||
Reference in New Issue
Block a user