Make ChatClient and Advisor APIs more robust - Part 1
- Introduce “ChatClientRequest” and “ChatClientResponse” for propagating requests/responses in a ChatClient advisor chain. - Structure a Prompt at the beginning of the chain, to ensure a consistent view across execution chain and observations. Any template is rendered at the beginning so that every advisor doesn’t have to do it again. - Improve observations to include the complete view of the prompt messages, instead of only considering userText and systemText. - Remove legacy “around” advisor type concept. - Keep backward compatibility for AdvisedRequest, AdvisedResponse, and legacy Advisor APIs. Relates to gh-2655 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
This commit is contained in:
committed by
Christian Tzolov
parent
593083980b
commit
1f59ccadad
@@ -42,6 +42,7 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.test.CurlyBracketEscaper;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -189,7 +190,7 @@ class AnthropicChatClientIT {
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.test.CurlyBracketEscaper;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
@@ -83,7 +84,7 @@ public class AzureOpenAiChatClientIT {
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
|
||||
.stream()
|
||||
.chatResponse();
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.test.CurlyBracketEscaper;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -182,7 +183,7 @@ class BedrockConverseChatClientIT {
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
|
||||
.stream()
|
||||
.chatResponse();
|
||||
|
||||
|
||||
@@ -28,12 +28,14 @@ import org.slf4j.LoggerFactory;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi;
|
||||
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
|
||||
import org.springframework.ai.test.CurlyBracketEscaper;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -198,10 +200,11 @@ class MistralAiChatClientIT {
|
||||
// @formatter:off
|
||||
Flux<String> chatResponse = ChatClient.create(this.chatModel)
|
||||
.prompt()
|
||||
.advisors(new SimpleLoggerAdvisor())
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
|
||||
.stream()
|
||||
.content();
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
|
||||
import org.springframework.ai.openai.api.tool.MockWeatherService;
|
||||
import org.springframework.ai.openai.testutils.AbstractIT;
|
||||
import org.springframework.ai.test.CurlyBracketEscaper;
|
||||
import org.springframework.ai.tool.function.FunctionToolCallback;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
@@ -220,7 +221,7 @@ class OpenAiChatClientIT extends AbstractIT {
|
||||
.user(u -> u
|
||||
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
|
||||
+ "{format}")
|
||||
.param("format", outputConverter.getFormat()))
|
||||
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
|
||||
.stream()
|
||||
.chatResponse();
|
||||
|
||||
|
||||
@@ -156,6 +156,8 @@ public interface ChatClient {
|
||||
@Nullable
|
||||
<T> T entity(Class<T> type);
|
||||
|
||||
ChatClientResponse chatClientResponse();
|
||||
|
||||
@Nullable
|
||||
ChatResponse chatResponse();
|
||||
|
||||
@@ -172,6 +174,8 @@ public interface ChatClient {
|
||||
|
||||
interface StreamResponseSpec {
|
||||
|
||||
Flux<ChatClientResponse> chatClientResponse();
|
||||
|
||||
Flux<ChatResponse> chatResponse();
|
||||
|
||||
Flux<String> content();
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
/**
|
||||
* Common attributes used in {@link ChatClient} context.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public enum ChatClientAttributes {
|
||||
|
||||
//@formatter:off
|
||||
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
ADVISORS("spring.ai.chat.client.advisors"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
CHAT_MODEL("spring.ai.chat.client.model"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
OUTPUT_FORMAT("spring.ai.chat.client.output.format"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
USER_PARAMS("spring.ai.chat.client.user.params"),
|
||||
@Deprecated // Only for backward compatibility until the next release.
|
||||
SYSTEM_PARAMS("spring.ai.chat.client.system.params");
|
||||
|
||||
//@formatter:on
|
||||
|
||||
private final String key;
|
||||
|
||||
ChatClientAttributes(String key) {
|
||||
this.key = key;
|
||||
}
|
||||
|
||||
public String getKey() {
|
||||
return key;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a request processed by a {@link ChatClient} that ultimately is used to build
|
||||
* a {@link Prompt} to be sent to an AI model.
|
||||
*
|
||||
* @param prompt The prompt to be sent to the AI model
|
||||
* @param context The contextual data through the execution chain
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public record ChatClientRequest(Prompt prompt, Map<String, Object> context) {
|
||||
|
||||
public ChatClientRequest {
|
||||
Assert.notNull(prompt, "prompt cannot be null");
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
Assert.noNullElements(context.keySet(), "context keys cannot be null");
|
||||
}
|
||||
|
||||
public Builder mutate() {
|
||||
return new Builder().prompt(this.prompt).context(this.context);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static final class Builder {
|
||||
|
||||
private Prompt prompt;
|
||||
|
||||
private Map<String, Object> context = new HashMap<>();
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
public Builder prompt(Prompt prompt) {
|
||||
Assert.notNull(prompt, "prompt cannot be null");
|
||||
this.prompt = prompt;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder context(Map<String, Object> context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
this.context.putAll(context);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder context(String key, Object value) {
|
||||
Assert.notNull(key, "key cannot be null");
|
||||
this.context.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientRequest build() {
|
||||
return new ChatClientRequest(prompt, context);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Represents a response returned by a {@link ChatClient}.
|
||||
*
|
||||
* @param chatResponse The response returned by the AI model
|
||||
* @param context The contextual data propagated through the execution chain
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map<String, Object> context) {
|
||||
|
||||
public ChatClientResponse {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
Assert.noNullElements(context.keySet(), "context keys cannot be null");
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private ChatResponse chatResponse;
|
||||
|
||||
private Map<String, Object> context = new HashMap<>();
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
public Builder chatResponse(ChatResponse chatResponse) {
|
||||
this.chatResponse = chatResponse;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder context(Map<String, Object> context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
this.context.putAll(context);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder context(String key, Object value) {
|
||||
Assert.notNull(key, "key cannot be null");
|
||||
this.context.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatClientResponse build() {
|
||||
return new ChatClientResponse(this.chatResponse, this.context);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -33,17 +32,19 @@ import java.util.function.Consumer;
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.ToolCallbacks;
|
||||
import org.springframework.lang.NonNull;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
|
||||
@@ -62,10 +63,6 @@ import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.converter.BeanOutputConverter;
|
||||
import org.springframework.ai.converter.StructuredOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.ai.tool.ToolCallbackProvider;
|
||||
import org.springframework.ai.tool.ToolCallbacks;
|
||||
import org.springframework.core.Ordered;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.Nullable;
|
||||
@@ -98,14 +95,10 @@ public class DefaultChatClient implements ChatClient {
|
||||
this.defaultChatClientRequest = defaultChatClientRequest;
|
||||
}
|
||||
|
||||
private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest,
|
||||
@Nullable String formatParam) {
|
||||
private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) {
|
||||
Assert.notNull(inputRequest, "inputRequest cannot be null");
|
||||
|
||||
Map<String, Object> advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams());
|
||||
if (StringUtils.hasText(formatParam)) {
|
||||
advisorContext.put("formatParam", formatParam);
|
||||
}
|
||||
|
||||
// Process userText, media and messages before creating the AdvisedRequest.
|
||||
String userText = inputRequest.userText;
|
||||
@@ -131,11 +124,12 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions,
|
||||
media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams,
|
||||
media, inputRequest.toolNames, inputRequest.toolCallbacks, messages, inputRequest.userParams,
|
||||
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
|
||||
inputRequest.toolContext);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
|
||||
ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) {
|
||||
|
||||
@@ -391,11 +385,25 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
public static class DefaultCallResponseSpec implements CallResponseSpec {
|
||||
|
||||
private final DefaultChatClientRequestSpec request;
|
||||
private final ChatClientRequest request;
|
||||
|
||||
public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) {
|
||||
Assert.notNull(request, "request cannot be null");
|
||||
this.request = request;
|
||||
private final BaseAdvisorChain advisorChain;
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
private final ChatClientObservationConvention observationConvention;
|
||||
|
||||
public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
|
||||
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
Assert.notNull(advisorChain, "advisorChain cannot be null");
|
||||
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
|
||||
Assert.notNull(observationConvention, "observationConvention cannot be null");
|
||||
|
||||
this.request = chatClientRequest;
|
||||
this.advisorChain = advisorChain;
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.observationConvention = observationConvention;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -419,7 +427,8 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> outputConverter) {
|
||||
Assert.notNull(outputConverter, "structuredOutputConverter cannot be null");
|
||||
var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat());
|
||||
var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
|
||||
.chatResponse();
|
||||
var responseContent = getContentFromChatResponse(chatResponse);
|
||||
if (responseContent == null) {
|
||||
return new ResponseEntity<>(chatResponse, null);
|
||||
@@ -452,7 +461,8 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
@Nullable
|
||||
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> outputConverter) {
|
||||
var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat());
|
||||
var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
|
||||
.chatResponse();
|
||||
var stringResponse = getContentFromChatResponse(chatResponse);
|
||||
if (stringResponse == null) {
|
||||
return null;
|
||||
@@ -460,38 +470,85 @@ public class DefaultChatClient implements ChatClient {
|
||||
return outputConverter.convert(stringResponse);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private ChatResponse doGetChatResponse() {
|
||||
return this.doGetObservableChatResponse(this.request, null);
|
||||
@Override
|
||||
public ChatClientResponse chatClientResponse() {
|
||||
return doGetObservableChatClientResponse(this.request);
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec inputRequest,
|
||||
@Nullable String formatParam) {
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetObservableChatClientResponse(this.request).chatResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
public String content() {
|
||||
ChatResponse chatResponse = doGetObservableChatClientResponse(this.request).chatResponse();
|
||||
return getContentFromChatResponse(chatResponse);
|
||||
}
|
||||
|
||||
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest) {
|
||||
return doGetObservableChatClientResponse(chatClientRequest, null);
|
||||
}
|
||||
|
||||
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest,
|
||||
@Nullable String outputFormat) {
|
||||
ChatClientRequest formattedChatClientRequest = StringUtils.hasText(outputFormat)
|
||||
? addFormatInstructionsToPrompt(chatClientRequest, outputFormat) : chatClientRequest;
|
||||
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(inputRequest)
|
||||
.withFormat(formatParam)
|
||||
.withStream(false)
|
||||
.request(formattedChatClientRequest)
|
||||
.stream(false)
|
||||
.withFormat(outputFormat)
|
||||
.build();
|
||||
|
||||
var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
|
||||
inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
|
||||
() -> observationContext, inputRequest.getObservationRegistry());
|
||||
return observation.observe(() -> doGetChatResponse(inputRequest, formatParam, observation));
|
||||
var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(observationConvention,
|
||||
DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry);
|
||||
var chatClientResponse = observation.observe(() -> {
|
||||
// Apply the advisor chain that terminates with the ChatModelCallAdvisor.
|
||||
return advisorChain.nextCall(formattedChatClientRequest);
|
||||
});
|
||||
return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build();
|
||||
}
|
||||
|
||||
private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec,
|
||||
@Nullable String formatParam, Observation parentObservation) {
|
||||
@NonNull
|
||||
private static ChatClientRequest addFormatInstructionsToPrompt(ChatClientRequest chatClientRequest,
|
||||
String outputFormat) {
|
||||
List<Message> originalMessages = chatClientRequest.prompt().getInstructions();
|
||||
|
||||
AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam);
|
||||
if (CollectionUtils.isEmpty(originalMessages)) {
|
||||
return chatClientRequest;
|
||||
}
|
||||
|
||||
// Apply the around advisor chain that terminates with the last model call
|
||||
// advisor.
|
||||
AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build()
|
||||
.nextAroundCall(advisedRequest);
|
||||
// Create a copy of the message list to avoid modifying the original.
|
||||
List<Message> modifiedMessages = new ArrayList<>(originalMessages);
|
||||
|
||||
return advisedResponse.response();
|
||||
// Get the last message (without removing it from original list)
|
||||
Message lastMessage = modifiedMessages.get(modifiedMessages.size() - 1);
|
||||
|
||||
// If the last message is a UserMessage, replace it with the modified version
|
||||
if (lastMessage instanceof UserMessage userMessage) {
|
||||
// Remove last message
|
||||
modifiedMessages.remove(modifiedMessages.size() - 1);
|
||||
|
||||
// Create new user message with format instructions
|
||||
UserMessage userMessageWithFormat = userMessage.mutate()
|
||||
.text(userMessage.getText() + System.lineSeparator() + outputFormat)
|
||||
.build();
|
||||
|
||||
// Add modified message back
|
||||
modifiedMessages.add(userMessageWithFormat);
|
||||
|
||||
// Build new ChatClientRequest preserving all properties but with modified
|
||||
// prompt
|
||||
return ChatClientRequest.builder()
|
||||
.prompt(chatClientRequest.prompt().mutate().messages(modifiedMessages).build())
|
||||
.context(Map.copyOf(chatClientRequest.context()))
|
||||
.build();
|
||||
}
|
||||
|
||||
return chatClientRequest;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@@ -503,53 +560,49 @@ public class DefaultChatClient implements ChatClient {
|
||||
.orElse(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
public ChatResponse chatResponse() {
|
||||
return doGetChatResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
@Nullable
|
||||
public String content() {
|
||||
ChatResponse chatResponse = doGetChatResponse();
|
||||
return getContentFromChatResponse(chatResponse);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class DefaultStreamResponseSpec implements StreamResponseSpec {
|
||||
|
||||
private final DefaultChatClientRequestSpec request;
|
||||
private final ChatClientRequest request;
|
||||
|
||||
public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) {
|
||||
Assert.notNull(request, "request cannot be null");
|
||||
this.request = request;
|
||||
private final BaseAdvisorChain advisorChain;
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
private final ChatClientObservationConvention observationConvention;
|
||||
|
||||
public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
|
||||
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
Assert.notNull(advisorChain, "advisorChain cannot be null");
|
||||
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
|
||||
Assert.notNull(observationConvention, "observationConvention cannot be null");
|
||||
|
||||
this.request = chatClientRequest;
|
||||
this.advisorChain = advisorChain;
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.observationConvention = observationConvention;
|
||||
}
|
||||
|
||||
private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
|
||||
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
|
||||
return Flux.deferContextual(contextView -> {
|
||||
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(inputRequest)
|
||||
.withStream(true)
|
||||
.request(chatClientRequest)
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
|
||||
inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
|
||||
() -> observationContext, inputRequest.getObservationRegistry());
|
||||
observationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext,
|
||||
observationRegistry);
|
||||
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
|
||||
.start();
|
||||
|
||||
var initialAdvisedRequest = toAdvisedRequest(inputRequest, null);
|
||||
|
||||
// @formatter:off
|
||||
// Apply the around advisor chain that terminates with the last model call advisor.
|
||||
Flux<AdvisedResponse> stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest);
|
||||
|
||||
return stream
|
||||
.map(AdvisedResponse::response)
|
||||
// Apply the advisor chain that terminates with the ChatModelStreamAdvisor.
|
||||
return advisorChain.nextStream(chatClientRequest)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
@@ -558,19 +611,29 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
public Flux<ChatClientResponse> chatClientResponse() {
|
||||
return doGetObservableFluxChatResponse(this.request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> chatResponse() {
|
||||
return doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<String> content() {
|
||||
return doGetObservableFluxChatResponse(this.request).map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getText() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getText();
|
||||
}).filter(StringUtils::hasLength);
|
||||
// @formatter:off
|
||||
return doGetObservableFluxChatResponse(this.request)
|
||||
.mapNotNull(ChatClientResponse::chatResponse)
|
||||
.map(r -> {
|
||||
if (r.getResult() == null || r.getResult().getOutput() == null
|
||||
|| r.getResult().getOutput().getText() == null) {
|
||||
return "";
|
||||
}
|
||||
return r.getResult().getOutput().getText();
|
||||
})
|
||||
.filter(StringUtils::hasLength);
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
}
|
||||
@@ -579,15 +642,15 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
private final ChatClientObservationConvention customObservationConvention;
|
||||
private final ChatClientObservationConvention observationConvention;
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
private final List<Media> media = new ArrayList<>();
|
||||
|
||||
private final List<String> functionNames = new ArrayList<>();
|
||||
private final List<String> toolNames = new ArrayList<>();
|
||||
|
||||
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
|
||||
private final List<FunctionCallback> toolCallbacks = new ArrayList<>();
|
||||
|
||||
private final List<Message> messages = new ArrayList<>();
|
||||
|
||||
@@ -614,25 +677,24 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
/* copy constructor */
|
||||
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
|
||||
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
|
||||
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
|
||||
ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext);
|
||||
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
|
||||
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
|
||||
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext);
|
||||
}
|
||||
|
||||
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
|
||||
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
|
||||
List<FunctionCallback> functionCallbacks, List<Message> messages, List<String> functionNames,
|
||||
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
|
||||
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
|
||||
@Nullable ChatClientObservationConvention customObservationConvention,
|
||||
Map<String, Object> toolContext) {
|
||||
List<FunctionCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
|
||||
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
|
||||
ObservationRegistry observationRegistry,
|
||||
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext) {
|
||||
|
||||
Assert.notNull(chatModel, "chatModel cannot be null");
|
||||
Assert.notNull(userParams, "userParams cannot be null");
|
||||
Assert.notNull(systemParams, "systemParams cannot be null");
|
||||
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
Assert.notNull(messages, "messages cannot be null");
|
||||
Assert.notNull(functionNames, "functionNames cannot be null");
|
||||
Assert.notNull(toolNames, "toolNames cannot be null");
|
||||
Assert.notNull(media, "media cannot be null");
|
||||
Assert.notNull(advisors, "advisors cannot be null");
|
||||
Assert.notNull(advisorParams, "advisorParams cannot be null");
|
||||
@@ -648,58 +710,21 @@ public class DefaultChatClient implements ChatClient {
|
||||
this.systemText = systemText;
|
||||
this.systemParams.putAll(systemParams);
|
||||
|
||||
this.functionNames.addAll(functionNames);
|
||||
this.functionCallbacks.addAll(functionCallbacks);
|
||||
this.toolNames.addAll(toolNames);
|
||||
this.toolCallbacks.addAll(toolCallbacks);
|
||||
this.messages.addAll(messages);
|
||||
this.media.addAll(media);
|
||||
this.advisors.addAll(advisors);
|
||||
this.advisorParams.putAll(advisorParams);
|
||||
this.observationRegistry = observationRegistry;
|
||||
this.customObservationConvention = customObservationConvention != null ? customObservationConvention
|
||||
this.observationConvention = observationConvention != null ? observationConvention
|
||||
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
|
||||
this.toolContext.putAll(toolContext);
|
||||
|
||||
// @formatter:off
|
||||
// At the stack bottom add the non-streaming and streaming model call advisors.
|
||||
// They play the role of the last advisor in the around advisor chain.
|
||||
this.advisors.add(new CallAroundAdvisor() {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return CallAroundAdvisor.class.getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return Ordered.LOWEST_PRECEDENCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext()));
|
||||
}
|
||||
});
|
||||
|
||||
this.advisors.add(new StreamAroundAdvisor() {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return StreamAroundAdvisor.class.getSimpleName();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return Ordered.LOWEST_PRECEDENCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
return chatModel.stream(advisedRequest.toPrompt())
|
||||
.map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
|
||||
.publishOn(Schedulers.boundedElastic()); // TODO add option to disable.
|
||||
}
|
||||
});
|
||||
// @formatter:on
|
||||
// At the stack bottom add the model call advisors.
|
||||
// They play the role of the last advisors in the advisor chain.
|
||||
this.advisors.add(new ChatModelCallAdvisor(chatModel));
|
||||
this.advisors.add(new ChatModelStreamAdvisor(chatModel));
|
||||
|
||||
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry)
|
||||
.pushAll(this.advisors);
|
||||
@@ -710,7 +735,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
private ChatClientObservationConvention getCustomObservationConvention() {
|
||||
return this.customObservationConvention;
|
||||
return this.observationConvention;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
@@ -753,11 +778,11 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
public List<String> getFunctionNames() {
|
||||
return this.functionNames;
|
||||
return this.toolNames;
|
||||
}
|
||||
|
||||
public List<FunctionCallback> getFunctionCallbacks() {
|
||||
return this.functionCallbacks;
|
||||
return this.toolCallbacks;
|
||||
}
|
||||
|
||||
public Map<String, Object> getToolContext() {
|
||||
@@ -770,8 +795,8 @@ public class DefaultChatClient implements ChatClient {
|
||||
*/
|
||||
public Builder mutate() {
|
||||
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
|
||||
.builder(this.chatModel, this.observationRegistry, this.customObservationConvention)
|
||||
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
|
||||
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
|
||||
.defaultTools(StringUtils.toStringArray(this.toolNames));
|
||||
|
||||
if (StringUtils.hasText(this.userText)) {
|
||||
builder.defaultUser(
|
||||
@@ -787,7 +812,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
builder.addMessages(this.messages);
|
||||
builder.addToolCallbacks(this.functionCallbacks);
|
||||
builder.addToolCallbacks(this.toolCallbacks);
|
||||
builder.addToolContext(this.toolContext);
|
||||
|
||||
return builder;
|
||||
@@ -843,7 +868,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec tools(String... toolNames) {
|
||||
Assert.notNull(toolNames, "toolNames cannot be null");
|
||||
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
|
||||
this.functionNames.addAll(List.of(toolNames));
|
||||
this.toolNames.addAll(List.of(toolNames));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -851,7 +876,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
|
||||
this.functionCallbacks.addAll(List.of(toolCallbacks));
|
||||
this.toolCallbacks.addAll(List.of(toolCallbacks));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -859,7 +884,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks) {
|
||||
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
|
||||
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
|
||||
this.functionCallbacks.addAll(toolCallbacks);
|
||||
this.toolCallbacks.addAll(toolCallbacks);
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -867,7 +892,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec tools(Object... toolObjects) {
|
||||
Assert.notNull(toolObjects, "toolObjects cannot be null");
|
||||
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
|
||||
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
|
||||
this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -876,7 +901,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
|
||||
Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
|
||||
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
|
||||
this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
|
||||
this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
@@ -890,7 +915,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
|
||||
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
|
||||
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");
|
||||
this.functionCallbacks.addAll(Arrays.asList(functionCallbacks));
|
||||
this.toolCallbacks.addAll(Arrays.asList(functionCallbacks));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -973,17 +998,22 @@ public class DefaultChatClient implements ChatClient {
|
||||
}
|
||||
|
||||
public CallResponseSpec call() {
|
||||
return new DefaultCallResponseSpec(this);
|
||||
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
|
||||
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
|
||||
observationRegistry, observationConvention);
|
||||
}
|
||||
|
||||
public StreamResponseSpec stream() {
|
||||
return new DefaultStreamResponseSpec(this);
|
||||
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
|
||||
return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
|
||||
observationRegistry, observationConvention);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Prompt
|
||||
|
||||
@Deprecated // never used, to be removed
|
||||
public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
@@ -1015,6 +1045,7 @@ public class DefaultChatClient implements ChatClient {
|
||||
|
||||
}
|
||||
|
||||
@Deprecated // never used, to be removed
|
||||
public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec {
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.core.Ordered;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A {@link CallAdvisor} that uses a {@link ChatModel} to generate a response.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public final class ChatModelCallAdvisor implements CallAdvisor {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public ChatModelCallAdvisor(ChatModel chatModel) {
|
||||
this.chatModel = chatModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt());
|
||||
return ChatClientResponse.builder()
|
||||
.chatResponse(chatResponse)
|
||||
.context(Map.copyOf(chatClientRequest.context()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "call";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return Ordered.LOWEST_PRECEDENCE;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.*;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.core.Ordered;
|
||||
import org.springframework.util.Assert;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.scheduler.Schedulers;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A {@link StreamAdvisor} that uses a {@link ChatModel} to generate a streaming response.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public final class ChatModelStreamAdvisor implements StreamAdvisor {
|
||||
|
||||
private final ChatModel chatModel;
|
||||
|
||||
public ChatModelStreamAdvisor(ChatModel chatModel) {
|
||||
this.chatModel = chatModel;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
return chatModel.stream(chatClientRequest.prompt())
|
||||
.map(chatResponse -> ChatClientResponse.builder()
|
||||
.chatResponse(chatResponse)
|
||||
.context(Map.copyOf(chatClientRequest.context()))
|
||||
.build())
|
||||
.publishOn(Schedulers.boundedElastic()); // TODO add option to disable
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "stream";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOrder() {
|
||||
return Ordered.LOWEST_PRECEDENCE;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -23,15 +23,18 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
|
||||
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
|
||||
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;
|
||||
@@ -41,16 +44,16 @@ import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Implementation of the {@link CallAroundAdvisorChain} and
|
||||
* {@link StreamAroundAdvisorChain}. Used by the
|
||||
* Default implementation for the {@link BaseAdvisorChain}. Used by the
|
||||
* {@link org.springframework.ai.chat.client.ChatClient} to delegate the call to the next
|
||||
* {@link CallAroundAdvisor} or {@link StreamAroundAdvisor} in the chain.
|
||||
* {@link CallAdvisor} or {@link StreamAdvisor} in the chain.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain {
|
||||
public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
|
||||
|
||||
public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
|
||||
|
||||
@@ -77,7 +80,42 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
if (this.callAroundAdvisors.isEmpty()) {
|
||||
throw new IllegalStateException("No CallAdvisors available to execute");
|
||||
}
|
||||
|
||||
var advisor = this.callAroundAdvisors.pop();
|
||||
|
||||
var observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.chatClientRequest(chatClientRequest)
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
return AdvisorObservationDocumentation.AI_ADVISOR
|
||||
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
|
||||
.observe(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof CallAdvisor callAdvisor) {
|
||||
return callAdvisor.adviseCall(chatClientRequest, this);
|
||||
}
|
||||
AdvisedResponse advisedResponse = advisor.aroundCall(AdvisedRequest.from(chatClientRequest), this);
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
observationContext.setChatClientResponse(chatClientResponse);
|
||||
return chatClientResponse;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #nextCall(ChatClientRequest)} instead
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
|
||||
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
|
||||
|
||||
if (this.callAroundAdvisors.isEmpty()) {
|
||||
throw new IllegalStateException("No AroundAdvisor available to execute");
|
||||
@@ -87,31 +125,40 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
|
||||
|
||||
var observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.advisedRequest(advisedRequest)
|
||||
.advisorRequestContext(advisedRequest.adviseContext())
|
||||
.chatClientRequest(advisedRequest.toChatClientRequest())
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
return AdvisorObservationDocumentation.AI_ADVISOR
|
||||
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
|
||||
.observe(() -> advisor.aroundCall(advisedRequest, this));
|
||||
.observe(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof CallAdvisor callAdvisor) {
|
||||
ChatClientResponse chatClientResponse = callAdvisor.adviseCall(advisedRequest.toChatClientRequest(),
|
||||
this);
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
AdvisedResponse advisedResponse = advisor.aroundCall(advisedRequest, this);
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
observationContext.setChatClientResponse(chatClientResponse);
|
||||
return advisedResponse;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
|
||||
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
|
||||
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
|
||||
|
||||
return Flux.deferContextual(contextView -> {
|
||||
if (this.streamAroundAdvisors.isEmpty()) {
|
||||
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
|
||||
return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
|
||||
}
|
||||
|
||||
var advisor = this.streamAroundAdvisors.pop();
|
||||
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.advisedRequest(advisedRequest)
|
||||
.advisorRequestContext(advisedRequest.adviseContext())
|
||||
.chatClientRequest(chatClientRequest)
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
@@ -121,10 +168,66 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
|
||||
|
||||
// @formatter:off
|
||||
return Flux.defer(() -> advisor.aroundStream(advisedRequest, this))
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
return Flux.defer(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof StreamAdvisor streamAdvisor) {
|
||||
return streamAdvisor.adviseStream(chatClientRequest, this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
}
|
||||
return advisor.aroundStream(AdvisedRequest.from(chatClientRequest), this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
|
||||
.map(AdvisedResponse::toChatClientResponse);
|
||||
});
|
||||
// @formatter:on
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #nextStream(ChatClientRequest)} instead.
|
||||
*/
|
||||
@Override
|
||||
@Deprecated
|
||||
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
|
||||
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
|
||||
|
||||
return Flux.deferContextual(contextView -> {
|
||||
if (this.streamAroundAdvisors.isEmpty()) {
|
||||
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
|
||||
}
|
||||
|
||||
var advisor = this.streamAroundAdvisors.pop();
|
||||
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName(advisor.getName())
|
||||
.chatClientRequest(advisedRequest.toChatClientRequest())
|
||||
.order(advisor.getOrder())
|
||||
.build();
|
||||
|
||||
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
|
||||
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
|
||||
|
||||
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
|
||||
|
||||
// @formatter:off
|
||||
return Flux.defer(() -> {
|
||||
// Supports both deprecated and new API.
|
||||
if (advisor instanceof StreamAdvisor streamAdvisor) {
|
||||
return streamAdvisor.adviseStream(advisedRequest.toChatClientRequest(), this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
|
||||
.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
return advisor.aroundStream(advisedRequest, this)
|
||||
.doOnError(observation::error)
|
||||
.doFinally(s -> observation.stop())
|
||||
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
|
||||
});
|
||||
// @formatter:on
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -20,10 +20,14 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
@@ -34,6 +38,7 @@ import org.springframework.ai.chat.prompt.PromptTemplate;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.function.FunctionCallingOptions;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -60,8 +65,10 @@ import org.springframework.util.StringUtils;
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @deprecated Use {@link ChatClientRequest} instead.
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@Deprecated
|
||||
public record AdvisedRequest(
|
||||
// @formatter:off
|
||||
ChatModel chatModel,
|
||||
@@ -77,6 +84,7 @@ public record AdvisedRequest(
|
||||
Map<String, Object> userParams,
|
||||
Map<String, Object> systemParams,
|
||||
List<Advisor> advisors,
|
||||
@Deprecated // Not really used. Use "adviseContext" instead.
|
||||
Map<String, Object> advisorParams,
|
||||
Map<String, Object> adviseContext,
|
||||
Map<String, Object> toolContext
|
||||
@@ -139,6 +147,52 @@ public record AdvisedRequest(
|
||||
return builder;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public static AdvisedRequest from(ChatClientRequest from) {
|
||||
Assert.notNull(from, "ChatClientRequest cannot be null");
|
||||
|
||||
List<Message> messages = new LinkedList<>(from.prompt().getInstructions());
|
||||
|
||||
Builder builder = new Builder();
|
||||
if (from.context().get(ChatClientAttributes.CHAT_MODEL.getKey()) instanceof ChatModel chatModel) {
|
||||
builder.chatModel = chatModel;
|
||||
}
|
||||
|
||||
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof UserMessage userMessage) {
|
||||
builder.userText = userMessage.getText();
|
||||
builder.media = userMessage.getMedia();
|
||||
messages.remove(messages.size() - 1);
|
||||
}
|
||||
if (from.context().get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> contextUserParams) {
|
||||
builder.userParams = (Map<String, Object>) contextUserParams;
|
||||
}
|
||||
|
||||
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof SystemMessage systemMessage) {
|
||||
builder.systemText = systemMessage.getText();
|
||||
messages.remove(messages.size() - 1);
|
||||
}
|
||||
if (from.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> contextSystemParams) {
|
||||
builder.systemParams = (Map<String, Object>) contextSystemParams;
|
||||
}
|
||||
|
||||
builder.messages = messages;
|
||||
|
||||
builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build());
|
||||
if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) {
|
||||
builder.functionNames = options.getToolNames().stream().toList();
|
||||
builder.functionCallbacks = options.getToolCallbacks();
|
||||
builder.toolContext = options.getToolContext();
|
||||
}
|
||||
|
||||
if (from.context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List<?> advisors) {
|
||||
builder.advisors = (List<Advisor>) advisors;
|
||||
}
|
||||
builder.advisorParams = Map.of();
|
||||
builder.adviseContext = from.context();
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
|
||||
Assert.notNull(contextTransform, "contextTransform cannot be null");
|
||||
return from(this)
|
||||
@@ -146,6 +200,17 @@ public record AdvisedRequest(
|
||||
.build();
|
||||
}
|
||||
|
||||
public ChatClientRequest toChatClientRequest() {
|
||||
return ChatClientRequest.builder()
|
||||
.prompt(toPrompt())
|
||||
.context(this.adviseContext)
|
||||
.context(ChatClientAttributes.ADVISORS.getKey(), this.advisors)
|
||||
.context(ChatClientAttributes.CHAT_MODEL.getKey(), this.chatModel)
|
||||
.context(ChatClientAttributes.USER_PARAMS.getKey(), this.userParams)
|
||||
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), this.systemParams)
|
||||
.build();
|
||||
}
|
||||
|
||||
public Prompt toPrompt() {
|
||||
var messages = new ArrayList<>(this.messages());
|
||||
|
||||
@@ -157,16 +222,9 @@ public record AdvisedRequest(
|
||||
messages.add(new SystemMessage(processedSystemText));
|
||||
}
|
||||
|
||||
String formatParam = (String) this.adviseContext().get("formatParam");
|
||||
|
||||
var processedUserText = StringUtils.hasText(formatParam)
|
||||
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();
|
||||
|
||||
if (StringUtils.hasText(processedUserText)) {
|
||||
if (StringUtils.hasText(this.userText())) {
|
||||
Map<String, Object> userParams = new HashMap<>(this.userParams());
|
||||
if (StringUtils.hasText(formatParam)) {
|
||||
userParams.put("spring_ai_soc_format", formatParam);
|
||||
}
|
||||
String processedUserText = this.userText();
|
||||
if (!CollectionUtils.isEmpty(userParams)) {
|
||||
processedUserText = new PromptTemplate(processedUserText, userParams).render();
|
||||
}
|
||||
@@ -338,7 +396,9 @@ public record AdvisedRequest(
|
||||
* Set the advisor params.
|
||||
* @param advisorParams the advisor params
|
||||
* @return this {@link Builder} instance
|
||||
* @deprecated in favor of {@link #adviseContext(Map)}
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorParams(Map<String, Object> advisorParams) {
|
||||
this.advisorParams = advisorParams;
|
||||
return this;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -21,6 +21,7 @@ import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
@@ -33,8 +34,10 @@ import org.springframework.util.Assert;
|
||||
* @author Christian Tzolov
|
||||
* @author Thomas Vitale
|
||||
* @author Ilayaperumal Gopinathan
|
||||
* @deprecated Use {@link ChatClientResponse} instead.
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@Deprecated
|
||||
public record AdvisedResponse(@Nullable ChatResponse response, Map<String, Object> adviseContext) {
|
||||
|
||||
/**
|
||||
@@ -66,6 +69,15 @@ public record AdvisedResponse(@Nullable ChatResponse response, Map<String, Objec
|
||||
return new Builder().response(advisedResponse.response).adviseContext(advisedResponse.adviseContext);
|
||||
}
|
||||
|
||||
public static AdvisedResponse from(ChatClientResponse chatClientResponse) {
|
||||
Assert.notNull(chatClientResponse, "chatClientResponse cannot be null");
|
||||
return new AdvisedResponse(chatClientResponse.chatResponse(), chatClientResponse.context());
|
||||
}
|
||||
|
||||
public ChatClientResponse toChatClientResponse() {
|
||||
return new ChatClientResponse(this.response, this.adviseContext);
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the context of the advised response.
|
||||
* @param contextTransform the function to transform the context
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -24,9 +24,9 @@ import org.springframework.core.Ordered;
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @see CallAroundAdvisor
|
||||
* @see StreamAroundAdvisor
|
||||
* @see CallAroundAdvisorChain
|
||||
* @see CallAdvisor
|
||||
* @see StreamAdvisor
|
||||
* @see BaseAdvisor
|
||||
*/
|
||||
public interface Advisor extends Ordered {
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
/**
|
||||
* A base interface for advisor chains that can be used to chain multiple advisors
|
||||
* together, both for call and stream advisors.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
|
||||
/**
|
||||
* Advisor for execution flows ultimately resulting in a call to an AI model
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface CallAdvisor extends CallAroundAdvisor {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
|
||||
ChatClientResponse chatClientResponse = adviseCall(advisedRequest.toChatClientRequest(), chain);
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
|
||||
ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
|
||||
/**
|
||||
* A chain of {@link CallAdvisor} instances orchestrating the execution of a
|
||||
* {@link ChatClientRequest} on the next {@link CallAdvisor} in the chain.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface CallAdvisorChain extends CallAroundAdvisorChain {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #nextCall(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
default AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
|
||||
ChatClientResponse chatClientResponse = nextCall(advisedRequest.toChatClientRequest());
|
||||
return AdvisedResponse.from(chatClientResponse);
|
||||
}
|
||||
|
||||
ChatClientResponse nextCall(ChatClientRequest chatClientRequest);
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,14 +16,17 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
|
||||
/**
|
||||
* Around advisor that wraps the ChatModel#call(Prompt) method.
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link CallAdvisor}
|
||||
*/
|
||||
|
||||
@Deprecated
|
||||
public interface CallAroundAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
@@ -31,7 +34,10 @@ public interface CallAroundAdvisor extends Advisor {
|
||||
* @param advisedRequest the advised request
|
||||
* @param chain the advisor chain
|
||||
* @return the response
|
||||
* @deprecated in favor of
|
||||
* {@link CallAdvisor#adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
|
||||
/**
|
||||
* The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain.
|
||||
* Used for non-streaming responses.
|
||||
@@ -23,7 +25,9 @@ package org.springframework.ai.chat.client.advisor.api;
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link CallAdvisorChain}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface CallAroundAdvisorChain {
|
||||
|
||||
/**
|
||||
@@ -32,7 +36,9 @@ public interface CallAroundAdvisorChain {
|
||||
* @param advisedRequest the request containing the data to be processed by the next
|
||||
* advisor in the chain.
|
||||
* @return the response generated by the next advisor in the chain.
|
||||
* @deprecated in favor of {@link CallAdvisorChain#nextCall(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest);
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* Advisor for execution flows ultimately resulting in a streaming call to an AI model.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface StreamAdvisor extends StreamAroundAdvisor {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
|
||||
Flux<ChatClientResponse> chatClientResponse = adviseStream(advisedRequest.toChatClientRequest(), chain);
|
||||
return chatClientResponse.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* A chain of {@link StreamAdvisor} instances orchestrating the execution of a
|
||||
* {@link ChatClientRequest} on the next {@link StreamAdvisor} in the chain.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public interface StreamAdvisorChain extends StreamAroundAdvisorChain {
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #nextStream(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
default Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
|
||||
Flux<ChatClientResponse> chatClientResponse = nextStream(advisedRequest.toChatClientRequest());
|
||||
return chatClientResponse.map(AdvisedResponse::from);
|
||||
}
|
||||
|
||||
Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest);
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
@@ -24,7 +25,9 @@ import reactor.core.publisher.Flux;
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link StreamAdvisor}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface StreamAroundAdvisor extends Advisor {
|
||||
|
||||
/**
|
||||
@@ -32,7 +35,10 @@ public interface StreamAroundAdvisor extends Advisor {
|
||||
* @param advisedRequest the advised request
|
||||
* @param chain the chain of advisors to execute
|
||||
* @return the result of the advised request
|
||||
* @deprecated in favor of
|
||||
* {@link StreamAdvisor#adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
|
||||
*/
|
||||
@Deprecated
|
||||
Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
package org.springframework.ai.chat.client.advisor.api;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
@@ -25,7 +26,9 @@ import reactor.core.publisher.Flux;
|
||||
* @author Christian Tzolov
|
||||
* @author Dariusz Jedrzejczyk
|
||||
* @since 1.0.0
|
||||
* @deprecated in favor of {@link StreamAdvisorChain}
|
||||
*/
|
||||
@Deprecated
|
||||
public interface StreamAroundAdvisorChain {
|
||||
|
||||
/**
|
||||
@@ -34,7 +37,9 @@ public interface StreamAroundAdvisorChain {
|
||||
* @param advisedRequest the request containing data of the chat client that can be
|
||||
* modified before execution
|
||||
* @return a Flux stream of AdvisedResponse objects
|
||||
* @deprecated in favor of {@link StreamAdvisorChain#nextStream(ChatClientRequest)}
|
||||
*/
|
||||
@Deprecated
|
||||
Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -21,10 +21,13 @@ import java.util.Map;
|
||||
import io.micrometer.observation.Observation;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Context used to store metadata for chat client advisors.
|
||||
@@ -37,26 +40,12 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
|
||||
private final String advisorName;
|
||||
|
||||
private final Type advisorType;
|
||||
private final ChatClientRequest chatClientRequest;
|
||||
|
||||
/**
|
||||
* The order of the advisor in the advisor chain.
|
||||
*/
|
||||
private final int order;
|
||||
|
||||
/**
|
||||
* The {@link AdvisedRequest} data to be advised. Represents the row
|
||||
* {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}.
|
||||
*/
|
||||
@Nullable
|
||||
private AdvisedRequest advisorRequest;
|
||||
|
||||
/**
|
||||
* The shared data between the advisors in the chain. It is shared between all request
|
||||
* and response advising points of all advisors in the chain.
|
||||
*/
|
||||
@Nullable
|
||||
private Map<String, Object> advisorRequestContext;
|
||||
private ChatClientResponse chatClientResponse;
|
||||
|
||||
/**
|
||||
* the shared data between the advisors in the chain. It is shared between all request
|
||||
@@ -73,18 +62,32 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
* @param advisorRequestContext the shared data between the advisors in the chain
|
||||
* @param advisorResponseContext the shared data between the advisors in the chain
|
||||
* @param order the order of the advisor in the advisor chain
|
||||
* @deprecated use the builder instead
|
||||
*/
|
||||
@Deprecated
|
||||
public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest,
|
||||
@Nullable Map<String, Object> advisorRequestContext, @Nullable Map<String, Object> advisorResponseContext,
|
||||
int order) {
|
||||
Assert.hasText(advisorName, "advisorName must not be null or empty");
|
||||
Assert.notNull(advisorType, "advisorType must not be null");
|
||||
Assert.hasText(advisorName, "advisorName cannot be null or empty");
|
||||
|
||||
this.advisorName = advisorName;
|
||||
this.advisorType = advisorType;
|
||||
this.advisorRequest = advisorRequest;
|
||||
this.advisorRequestContext = advisorRequestContext;
|
||||
this.advisorResponseContext = advisorResponseContext;
|
||||
this.chatClientRequest = advisorRequest != null ? advisorRequest.toChatClientRequest()
|
||||
: ChatClientRequest.builder().prompt(new Prompt()).build();
|
||||
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
|
||||
this.chatClientRequest.context().putAll(advisorRequestContext);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(advisorResponseContext)) {
|
||||
this.chatClientResponse = ChatClientResponse.builder().context(advisorResponseContext).build();
|
||||
}
|
||||
this.order = order;
|
||||
}
|
||||
|
||||
AdvisorObservationContext(String advisorName, ChatClientRequest chatClientRequest, int order) {
|
||||
Assert.hasText(advisorName, "advisorName cannot be null or empty");
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
|
||||
this.advisorName = advisorName;
|
||||
this.chatClientRequest = chatClientRequest;
|
||||
this.order = order;
|
||||
}
|
||||
|
||||
@@ -96,89 +99,115 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* The advisor name.
|
||||
* @return the advisor name
|
||||
*/
|
||||
public String getAdvisorName() {
|
||||
return this.advisorName;
|
||||
}
|
||||
|
||||
public ChatClientRequest getChatClientRequest() {
|
||||
return this.chatClientRequest;
|
||||
}
|
||||
|
||||
public int getOrder() {
|
||||
return this.order;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public ChatClientResponse getChatClientResponse() {
|
||||
return this.chatClientResponse;
|
||||
}
|
||||
|
||||
public void setChatClientResponse(@Nullable ChatClientResponse chatClientResponse) {
|
||||
this.chatClientResponse = chatClientResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* The type of the advisor.
|
||||
* @return the type of the advisor
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public Type getAdvisorType() {
|
||||
return this.advisorType;
|
||||
return Type.AROUND;
|
||||
}
|
||||
|
||||
/**
|
||||
* The order of the advisor in the advisor chain.
|
||||
* @return the order of the advisor in the advisor chain
|
||||
* @deprecated not used anymore
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public AdvisedRequest getAdvisedRequest() {
|
||||
return this.advisorRequest;
|
||||
return AdvisedRequest.from(this.chatClientRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the {@link AdvisedRequest} data to be advised. Represents the row
|
||||
* {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}.
|
||||
* @param advisedRequest the advised request
|
||||
* @deprecated immutable object, use the builder instead to create a new instance
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisedRequest(@Nullable AdvisedRequest advisedRequest) {
|
||||
this.advisorRequest = advisedRequest;
|
||||
throw new IllegalStateException(
|
||||
"The AdvisedRequest is immutable. Build a new AdvisorObservationContext instead.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @return the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #getChatClientRequest()} instead
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public Map<String, Object> getAdvisorRequestContext() {
|
||||
return this.advisorRequestContext;
|
||||
return this.chatClientRequest.context();
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @param advisorRequestContext the shared data between the advisors in the chain
|
||||
* @deprecated not supported anymore, use {@link #getChatClientRequest()} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisorRequestContext(@Nullable Map<String, Object> advisorRequestContext) {
|
||||
this.advisorRequestContext = advisorRequestContext;
|
||||
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
|
||||
this.chatClientRequest.context().putAll(advisorRequestContext);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @return the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #getChatClientResponse()} instead
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public Map<String, Object> getAdvisorResponseContext() {
|
||||
return this.advisorResponseContext;
|
||||
if (this.chatClientResponse != null) {
|
||||
return this.chatClientResponse.context();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the shared data between the advisors in the chain. It is shared between all
|
||||
* request and response advising points of all advisors in the chain.
|
||||
* @param advisorResponseContext the shared data between the advisors in the chain
|
||||
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public void setAdvisorResponseContext(@Nullable Map<String, Object> advisorResponseContext) {
|
||||
this.advisorResponseContext = advisorResponseContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* The order of the advisor in the advisor chain.
|
||||
* @return the order of the advisor in the advisor chain
|
||||
*/
|
||||
public int getOrder() {
|
||||
return this.order;
|
||||
}
|
||||
|
||||
/**
|
||||
* The type of the advisor.
|
||||
*
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public enum Type {
|
||||
|
||||
/**
|
||||
@@ -203,7 +232,9 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
|
||||
private String advisorName;
|
||||
|
||||
private Type advisorType;
|
||||
private ChatClientRequest chatClientRequest;
|
||||
|
||||
private int order = 0;
|
||||
|
||||
private AdvisedRequest advisorRequest;
|
||||
|
||||
@@ -211,28 +242,32 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
|
||||
private Map<String, Object> advisorResponseContext;
|
||||
|
||||
private int order = 0;
|
||||
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor name.
|
||||
* @param advisorName the advisor name
|
||||
* @return the builder
|
||||
*/
|
||||
public Builder advisorName(String advisorName) {
|
||||
this.advisorName = advisorName;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder chatClientRequest(ChatClientRequest chatClientRequest) {
|
||||
this.chatClientRequest = chatClientRequest;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder order(int order) {
|
||||
this.order = order;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the advisor type.
|
||||
* @param advisorType the advisor type
|
||||
* @return the builder
|
||||
* @deprecated advisors don't have types anymore, they're all "around"
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorType(Type advisorType) {
|
||||
this.advisorType = advisorType;
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -240,7 +275,9 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
* Set the advised request.
|
||||
* @param advisedRequest the advised request
|
||||
* @return the builder
|
||||
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisedRequest(AdvisedRequest advisedRequest) {
|
||||
this.advisorRequest = advisedRequest;
|
||||
return this;
|
||||
@@ -250,7 +287,9 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
* Set the advisor request context.
|
||||
* @param advisorRequestContext the advisor request context
|
||||
* @return the builder
|
||||
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorRequestContext(Map<String, Object> advisorRequestContext) {
|
||||
this.advisorRequestContext = advisorRequestContext;
|
||||
return this;
|
||||
@@ -260,29 +299,26 @@ public class AdvisorObservationContext extends Observation.Context {
|
||||
* Set the advisor response context.
|
||||
* @param advisorResponseContext the advisor response context
|
||||
* @return the builder
|
||||
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder advisorResponseContext(Map<String, Object> advisorResponseContext) {
|
||||
this.advisorResponseContext = advisorResponseContext;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the order of the advisor in the advisor chain.
|
||||
* @param order the order of the advisor in the advisor chain
|
||||
* @return the builder
|
||||
*/
|
||||
public Builder order(int order) {
|
||||
this.order = order;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build the {@link AdvisorObservationContext}.
|
||||
* @return the {@link AdvisorObservationContext}
|
||||
*/
|
||||
public AdvisorObservationContext build() {
|
||||
return new AdvisorObservationContext(this.advisorName, this.advisorType, this.advisorRequest,
|
||||
this.advisorRequestContext, this.advisorResponseContext, this.order);
|
||||
if (chatClientRequest != null && advisorRequest != null) {
|
||||
throw new IllegalArgumentException(
|
||||
"ChatClientRequest and AdvisedRequest cannot be set at the same time");
|
||||
}
|
||||
else if (chatClientRequest != null) {
|
||||
return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order);
|
||||
}
|
||||
else {
|
||||
return new AdvisorObservationContext(this.advisorName, Type.AROUND, this.advisorRequest,
|
||||
this.advisorRequestContext, this.advisorResponseContext, this.order);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -26,6 +26,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.observation.conventions.SpringAiKind;
|
||||
import org.springframework.ai.util.ParsingUtils;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
/**
|
||||
* Default implementation of the {@link AdvisorObservationConvention}.
|
||||
@@ -55,6 +56,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
@Override
|
||||
@Nullable
|
||||
public String getContextualName(AdvisorObservationContext context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
return ParsingUtils.reConcatenateCamelCase(context.getAdvisorName(), "_")
|
||||
.replace("_around_advisor", "")
|
||||
.replace("_advisor", "");
|
||||
@@ -66,6 +68,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
|
||||
@Override
|
||||
public KeyValues getLowCardinalityKeyValues(AdvisorObservationContext context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorType(context),
|
||||
advisorName(context));
|
||||
}
|
||||
@@ -78,6 +81,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
return KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER, AiProvider.SPRING_AI.value());
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
protected KeyValue advisorType(AdvisorObservationContext context) {
|
||||
return KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE, context.getAdvisorType().name());
|
||||
}
|
||||
@@ -96,6 +100,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
|
||||
|
||||
@Override
|
||||
public KeyValues getHighCardinalityKeyValues(AdvisorObservationContext context) {
|
||||
Assert.notNull(context, "context cannot be null");
|
||||
return KeyValues.of(advisorOrder(context));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -20,9 +20,15 @@ import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationFilter;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.observation.tracing.TracingHelper;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An {@link ObservationFilter} to include the chat prompt content in the observation.
|
||||
@@ -37,7 +43,8 @@ public class ChatClientInputContentObservationFilter implements ObservationFilte
|
||||
if (!(context instanceof ChatClientObservationContext chatClientObservationContext)) {
|
||||
return context;
|
||||
}
|
||||
|
||||
// TODO: we really want these? Should probably align with same format as chat
|
||||
// model observation
|
||||
chatClientSystemText(chatClientObservationContext);
|
||||
chatClientSystemParams(chatClientObservationContext);
|
||||
chatClientUserText(chatClientObservationContext);
|
||||
@@ -47,39 +54,65 @@ public class ChatClientInputContentObservationFilter implements ObservationFilte
|
||||
}
|
||||
|
||||
protected void chatClientSystemText(ChatClientObservationContext context) {
|
||||
if (!StringUtils.hasText(context.getRequest().getSystemText())) {
|
||||
List<Message> messages = context.getRequest().prompt().getInstructions();
|
||||
if (CollectionUtils.isEmpty(messages)) {
|
||||
return;
|
||||
}
|
||||
|
||||
var systemMessage = messages.stream()
|
||||
.filter(message -> message instanceof SystemMessage)
|
||||
.reduce((first, second) -> second);
|
||||
if (systemMessage.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT,
|
||||
context.getRequest().getSystemText()));
|
||||
systemMessage.get().getText()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected void chatClientSystemParams(ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getSystemParams())) {
|
||||
if (!(context.getRequest()
|
||||
.context()
|
||||
.get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> systemParams)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(systemParams)) {
|
||||
return;
|
||||
}
|
||||
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM,
|
||||
TracingHelper.concatenateMaps(context.getRequest().getSystemParams())));
|
||||
TracingHelper.concatenateMaps((Map<String, Object>) systemParams)));
|
||||
}
|
||||
|
||||
protected void chatClientUserText(ChatClientObservationContext context) {
|
||||
if (!StringUtils.hasText(context.getRequest().getUserText())) {
|
||||
List<Message> messages = context.getRequest().prompt().getInstructions();
|
||||
if (CollectionUtils.isEmpty(messages)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!(messages.get(messages.size() - 1) instanceof UserMessage userMessage)) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT,
|
||||
context.getRequest().getUserText()));
|
||||
userMessage.getText()));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected void chatClientUserParams(ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getUserParams())) {
|
||||
if (!(context.getRequest()
|
||||
.context()
|
||||
.get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> userParams)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(userParams)) {
|
||||
return;
|
||||
}
|
||||
context.addHighCardinalityKeyValue(
|
||||
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS,
|
||||
TracingHelper.concatenateMaps(context.getRequest().getUserParams())));
|
||||
TracingHelper.concatenateMaps((Map<String, Object>) userParams)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -18,12 +18,14 @@ package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import io.micrometer.observation.Observation;
|
||||
|
||||
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.observation.AiOperationMetadata;
|
||||
import org.springframework.ai.observation.conventions.AiOperationType;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Context used to store metadata for chat client workflows.
|
||||
@@ -34,20 +36,16 @@ import org.springframework.util.Assert;
|
||||
*/
|
||||
public class ChatClientObservationContext extends Observation.Context {
|
||||
|
||||
private final DefaultChatClientRequestSpec request;
|
||||
private final ChatClientRequest request;
|
||||
|
||||
private final AiOperationMetadata operationMetadata = new AiOperationMetadata(AiOperationType.FRAMEWORK.value(),
|
||||
AiProvider.SPRING_AI.value());
|
||||
|
||||
private final boolean stream;
|
||||
|
||||
@Nullable
|
||||
private String format;
|
||||
|
||||
ChatClientObservationContext(DefaultChatClientRequestSpec requestSpec, String format, boolean isStream) {
|
||||
Assert.notNull(requestSpec, "requestSpec cannot be null");
|
||||
this.request = requestSpec;
|
||||
this.format = format;
|
||||
ChatClientObservationContext(ChatClientRequest chatClientRequest, boolean isStream) {
|
||||
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
|
||||
this.request = chatClientRequest;
|
||||
this.stream = isStream;
|
||||
}
|
||||
|
||||
@@ -55,7 +53,7 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public DefaultChatClientRequestSpec getRequest() {
|
||||
public ChatClientRequest getRequest() {
|
||||
return this.request;
|
||||
}
|
||||
|
||||
@@ -67,18 +65,31 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
return this.stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in the
|
||||
* ChatModelObservationContext.
|
||||
*/
|
||||
@Nullable
|
||||
@Deprecated
|
||||
public String getFormat() {
|
||||
return this.format;
|
||||
if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) {
|
||||
return format;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in the
|
||||
* ChatModelObservationContext.
|
||||
*/
|
||||
@Deprecated
|
||||
public void setFormat(@Nullable String format) {
|
||||
this.format = format;
|
||||
this.request.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
|
||||
}
|
||||
|
||||
public static final class Builder {
|
||||
|
||||
private DefaultChatClientRequestSpec request;
|
||||
private ChatClientRequest chatClientRequest;
|
||||
|
||||
private String format;
|
||||
|
||||
@@ -87,23 +98,41 @@ public class ChatClientObservationContext extends Observation.Context {
|
||||
private Builder() {
|
||||
}
|
||||
|
||||
public Builder withRequest(DefaultChatClientRequestSpec request) {
|
||||
this.request = request;
|
||||
public Builder request(ChatClientRequest chatClientRequest) {
|
||||
this.chatClientRequest = chatClientRequest;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated // use request(ChatClientRequest chatClientRequest)
|
||||
public Builder withRequest(ChatClientRequest chatClientRequest) {
|
||||
return request(chatClientRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated not used anymore. The format instructions are already included in
|
||||
* the ChatModelObservationContext.
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder withFormat(String format) {
|
||||
this.format = format;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder withStream(boolean isStream) {
|
||||
public Builder stream(boolean isStream) {
|
||||
this.isStream = isStream;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Deprecated // use stream(boolean isStream)
|
||||
public Builder withStream(boolean isStream) {
|
||||
return stream(isStream);
|
||||
}
|
||||
|
||||
public ChatClientObservationContext build() {
|
||||
return new ChatClientObservationContext(this.request, this.format, this.isStream);
|
||||
if (StringUtils.hasText(format)) {
|
||||
this.chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
|
||||
}
|
||||
return new ChatClientObservationContext(this.chatClientRequest, this.isStream);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -19,15 +19,20 @@ package org.springframework.ai.chat.client.observation;
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.common.KeyValues;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.advisor.api.Advisor;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.observation.conventions.SpringAiKind;
|
||||
import org.springframework.ai.observation.tracing.TracingHelper;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Default conventions to populate observations for chat client workflows.
|
||||
*
|
||||
@@ -88,53 +93,71 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva
|
||||
public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext context) {
|
||||
var keyValues = KeyValues.empty();
|
||||
keyValues = chatClientAdvisorNames(keyValues, context);
|
||||
// TODO: rename attribute? any sensitive data here?
|
||||
keyValues = chatClientAdvisorParams(keyValues, context);
|
||||
keyValues = toolFunctionNames(keyValues, context);
|
||||
keyValues = toolFunctionCallbacks(keyValues, context);
|
||||
// TODO: remove this? Already included in chat model observation
|
||||
keyValues = toolNames(keyValues, context);
|
||||
// TODO: remove this? Already included in chat model observation
|
||||
keyValues = toolCallbacks(keyValues, context);
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
protected KeyValues chatClientAdvisorNames(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getAdvisors())) {
|
||||
if (!(context.getRequest().context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List<?> advisors)) {
|
||||
return keyValues;
|
||||
}
|
||||
var advisorNames = context.getRequest().getAdvisors().stream().map(Advisor::getName).toList();
|
||||
var advisorNames = ((List<Advisor>) advisors).stream().map(Advisor::getName).toList();
|
||||
return keyValues.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(),
|
||||
TracingHelper.concatenateStrings(advisorNames));
|
||||
}
|
||||
|
||||
protected KeyValues chatClientAdvisorParams(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getAdvisorParams())) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().context())) {
|
||||
return keyValues;
|
||||
}
|
||||
var advisorParams = context.getRequest().getAdvisorParams();
|
||||
var chatClientContext = context.getRequest().context();
|
||||
Arrays.stream(ChatClientAttributes.values()).forEach(attribute -> chatClientContext.remove(attribute.getKey()));
|
||||
return keyValues.and(
|
||||
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(),
|
||||
TracingHelper.concatenateMaps(advisorParams));
|
||||
TracingHelper.concatenateMaps(chatClientContext));
|
||||
}
|
||||
|
||||
protected KeyValues toolFunctionNames(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getFunctionNames())) {
|
||||
protected KeyValues toolNames(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (context.getRequest().prompt().getOptions() == null) {
|
||||
return keyValues;
|
||||
}
|
||||
var functionNames = context.getRequest().getFunctionNames();
|
||||
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolNames = options.getToolNames();
|
||||
if (CollectionUtils.isEmpty(toolNames)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
return keyValues.and(
|
||||
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(),
|
||||
TracingHelper.concatenateStrings(functionNames));
|
||||
TracingHelper.concatenateStrings(toolNames.stream().sorted().toList()));
|
||||
}
|
||||
|
||||
protected KeyValues toolFunctionCallbacks(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (CollectionUtils.isEmpty(context.getRequest().getFunctionCallbacks())) {
|
||||
protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationContext context) {
|
||||
if (context.getRequest().prompt().getOptions() == null) {
|
||||
return keyValues;
|
||||
}
|
||||
var functionCallbacks = context.getRequest()
|
||||
.getFunctionCallbacks()
|
||||
.stream()
|
||||
.map(FunctionCallback::getName)
|
||||
.toList();
|
||||
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolCallbacks = options.getToolCallbacks();
|
||||
if (CollectionUtils.isEmpty(toolCallbacks)) {
|
||||
return keyValues;
|
||||
}
|
||||
|
||||
var toolCallbackNames = toolCallbacks.stream().map(FunctionCallback::getName).sorted().toList();
|
||||
return keyValues
|
||||
.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS
|
||||
.asString(), TracingHelper.concatenateStrings(functionCallbacks));
|
||||
.asString(), TracingHelper.concatenateStrings(toolCallbackNames));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link ChatClientRequest}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class ChatClientRequestTests {
|
||||
|
||||
@Test
|
||||
void whenPromptIsNullThenThrow() {
|
||||
assertThatThrownBy(() -> new ChatClientRequest(null, Map.of())).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("prompt cannot be null");
|
||||
|
||||
assertThatThrownBy(() -> ChatClientRequest.builder().prompt(null).context(Map.of()).build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("prompt cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenContextIsNullThenThrow() {
|
||||
assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context cannot be null");
|
||||
|
||||
assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context(null).build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenContextHasNullKeysThenThrow() {
|
||||
Map<String, Object> context = new HashMap<>();
|
||||
context.put(null, "something");
|
||||
assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), context))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context keys cannot be null");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link ChatClientResponse}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class ChatClientResponseTests {
|
||||
|
||||
@Test
|
||||
void whenContextIsNullThenThrow() {
|
||||
assertThatThrownBy(() -> new ChatClientResponse(null, null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context cannot be null");
|
||||
|
||||
assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(null).build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenContextHasNullKeysThenThrow() {
|
||||
Map<String, Object> context = new HashMap<>();
|
||||
context.put(null, "something");
|
||||
assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("context keys cannot be null");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -29,6 +29,9 @@ import java.util.function.Consumer;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
|
||||
@@ -46,7 +49,6 @@ import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.converter.ListOutputConverter;
|
||||
import org.springframework.ai.converter.StructuredOutputConverter;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.core.convert.support.DefaultConversionService;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
@@ -598,17 +600,67 @@ class DefaultChatClientTests {
|
||||
void buildCallResponseSpec() {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt();
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
.prompt("question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
assertThat(spec).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildCallResponseSpecWithNullRequest() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null))
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null, mock(BaseAdvisorChain.class),
|
||||
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("request cannot be null");
|
||||
.hasMessage("chatClientRequest cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildCallResponseSpecWithNullAdvisorChain() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class), null,
|
||||
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisorChain cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildCallResponseSpecWithNullObservationRegistry() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class),
|
||||
mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("observationRegistry cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildCallResponseSpecWithNullObservationConvention() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class),
|
||||
mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("observationConvention cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSimplePromptThenChatClientResponse() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
given(chatModel.call(promptCaptor.capture()))
|
||||
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
|
||||
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatClientResponse chatClientResponse = spec.chatClientResponse();
|
||||
assertThat(chatClientResponse).isNotNull();
|
||||
|
||||
ChatResponse chatResponse = chatClientResponse.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response");
|
||||
|
||||
Prompt actualPrompt = promptCaptor.getValue();
|
||||
assertThat(actualPrompt.getInstructions()).hasSize(1);
|
||||
assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -621,8 +673,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -644,8 +696,8 @@ class DefaultChatClientTests {
|
||||
Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question"));
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt(prompt);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -669,8 +721,8 @@ class DefaultChatClientTests {
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt(prompt)
|
||||
.user("another question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -696,8 +748,8 @@ class DefaultChatClientTests {
|
||||
.prompt()
|
||||
.user("another question")
|
||||
.messages(messages);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -719,8 +771,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse();
|
||||
assertThat(chatResponse).isNull();
|
||||
@@ -736,8 +788,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
String content = spec.content();
|
||||
assertThat(content).isNull();
|
||||
@@ -748,10 +800,10 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference) null))
|
||||
assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference<?>) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("type cannot be null");
|
||||
}
|
||||
@@ -766,8 +818,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
|
||||
.responseEntity(new ParameterizedTypeReference<>() {
|
||||
@@ -793,8 +845,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, List<Person>> responseEntity = spec
|
||||
.responseEntity(new ParameterizedTypeReference<>() {
|
||||
@@ -808,10 +860,10 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter) null))
|
||||
assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter<?>) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("structuredOutputConverter cannot be null");
|
||||
}
|
||||
@@ -826,8 +878,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
|
||||
.responseEntity(new ListOutputConverter(new DefaultConversionService()));
|
||||
@@ -847,8 +899,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
|
||||
.responseEntity(new ListOutputConverter(new DefaultConversionService()));
|
||||
@@ -861,8 +913,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.responseEntity((Class) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("type cannot be null");
|
||||
@@ -878,8 +930,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, String> responseEntity = spec.responseEntity(String.class);
|
||||
assertThat(responseEntity.response()).isNotNull();
|
||||
@@ -898,8 +950,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
ResponseEntity<ChatResponse, Person> responseEntity = spec.responseEntity(Person.class);
|
||||
assertThat(responseEntity.response()).isNotNull();
|
||||
@@ -912,10 +964,10 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference) null))
|
||||
assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference<?>) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("type cannot be null");
|
||||
}
|
||||
@@ -930,8 +982,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
List<String> entity = spec.entity(new ParameterizedTypeReference<>() {
|
||||
});
|
||||
@@ -954,8 +1006,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
List<Person> entity = spec.entity(new ParameterizedTypeReference<>() {
|
||||
});
|
||||
@@ -967,10 +1019,10 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.entity((StructuredOutputConverter) null))
|
||||
assertThatThrownBy(() -> spec.entity((StructuredOutputConverter<?>) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("structuredOutputConverter cannot be null");
|
||||
}
|
||||
@@ -980,8 +1032,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
List<String> entity = spec.entity(new ListOutputConverter(new DefaultConversionService()));
|
||||
assertThat(entity).isNull();
|
||||
@@ -999,8 +1051,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
List<String> entity = spec.entity(new ListOutputConverter(new DefaultConversionService()));
|
||||
assertThat(entity).hasSize(3);
|
||||
@@ -1011,10 +1063,10 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
assertThatThrownBy(() -> spec.entity((Class) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
assertThatThrownBy(() -> spec.entity((Class<?>) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("type cannot be null");
|
||||
}
|
||||
|
||||
@@ -1028,8 +1080,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
String entity = spec.entity(String.class);
|
||||
assertThat(entity).isNull();
|
||||
@@ -1047,8 +1099,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
|
||||
.call();
|
||||
|
||||
Person entity = spec.entity(Person.class);
|
||||
assertThat(entity).isNotNull();
|
||||
@@ -1061,17 +1113,67 @@ class DefaultChatClientTests {
|
||||
void buildStreamResponseSpec() {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt();
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
.prompt("question");
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
assertThat(spec).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildStreamResponseSpecWithNullRequest() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null))
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null, mock(BaseAdvisorChain.class),
|
||||
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("request cannot be null");
|
||||
.hasMessage("chatClientRequest cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildStreamResponseSpecWithNullAdvisorChain() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class), null,
|
||||
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisorChain cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildStreamResponseSpecWithNullObservationRegistry() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class),
|
||||
mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class)))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("observationRegistry cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void buildStreamResponseSpecWithNullObservationConvention() {
|
||||
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class),
|
||||
mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("observationConvention cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenSimplePromptThenFluxChatClientResponse() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
|
||||
given(chatModel.stream(promptCaptor.capture()))
|
||||
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))));
|
||||
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
ChatClientResponse chatClientResponse = spec.chatClientResponse().blockLast();
|
||||
assertThat(chatClientResponse).isNotNull();
|
||||
|
||||
ChatResponse chatResponse = chatClientResponse.chatResponse();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response");
|
||||
|
||||
Prompt actualPrompt = promptCaptor.getValue();
|
||||
assertThat(actualPrompt.getInstructions()).hasSize(1);
|
||||
assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -1084,8 +1186,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse().blockLast();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -1107,8 +1209,8 @@ class DefaultChatClientTests {
|
||||
Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question"));
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt(prompt);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse().blockLast();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -1132,8 +1234,8 @@ class DefaultChatClientTests {
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt(prompt)
|
||||
.user("another question");
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse().blockLast();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -1159,8 +1261,9 @@ class DefaultChatClientTests {
|
||||
.prompt()
|
||||
.user("another question")
|
||||
.messages(messages);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
ChatResponse chatResponse = spec.chatResponse().blockLast();
|
||||
assertThat(chatResponse).isNotNull();
|
||||
@@ -1183,8 +1286,8 @@ class DefaultChatClientTests {
|
||||
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
|
||||
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
|
||||
.prompt("my question");
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
|
||||
chatClientRequestSpec);
|
||||
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
|
||||
.stream();
|
||||
|
||||
String content = spec.content().blockLast();
|
||||
assertThat(content).isNull();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -21,7 +21,16 @@ import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.tool.ToolCallback;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
@@ -147,4 +156,58 @@ class AdvisedRequestTests {
|
||||
.hasMessage("toolContext cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenConvertToAndFromChatClientRequest() {
|
||||
ChatModel chatModel = mock(ChatModel.class);
|
||||
ChatOptions chatOptions = ToolCallingChatOptions.builder().build();
|
||||
List<Message> messages = List.of(mock(UserMessage.class));
|
||||
SystemMessage systemMessage = new SystemMessage("Instructions {key}");
|
||||
UserMessage userMessage = new UserMessage("Question {key}", mock(Media.class));
|
||||
Map<String, Object> systemParams = Map.of("key", "value");
|
||||
Map<String, Object> userParams = Map.of("key", "value");
|
||||
List<String> toolNames = List.of("tool1", "tool2");
|
||||
ToolCallback toolCallback = mock(ToolCallback.class);
|
||||
Map<String, Object> toolContext = Map.of("key", "value");
|
||||
List<Advisor> advisors = List.of(mock(Advisor.class));
|
||||
Map<String, Object> advisorContext = Map.of("key", "value");
|
||||
|
||||
AdvisedRequest advisedRequest = AdvisedRequest.builder()
|
||||
.chatModel(chatModel)
|
||||
.chatOptions(chatOptions)
|
||||
.messages(messages)
|
||||
.systemText(systemMessage.getText())
|
||||
.systemParams(systemParams)
|
||||
.userText(userMessage.getText())
|
||||
.userParams(userParams)
|
||||
.media(userMessage.getMedia())
|
||||
.functionNames(toolNames)
|
||||
.functionCallbacks(List.of(toolCallback))
|
||||
.toolContext(toolContext)
|
||||
.advisors(advisors)
|
||||
.adviseContext(advisorContext)
|
||||
.build();
|
||||
|
||||
ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest();
|
||||
|
||||
assertThat(chatClientRequest.context().get(ChatClientAttributes.CHAT_MODEL.getKey())).isEqualTo(chatModel);
|
||||
assertThat(chatClientRequest.prompt().getOptions()).isEqualTo(chatOptions);
|
||||
assertThat(chatClientRequest.prompt().getInstructions()).hasSize(3);
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(0)).isEqualTo(messages.get(0));
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(1).getText()).isEqualTo("Instructions value");
|
||||
assertThat(chatClientRequest.prompt().getInstructions().get(2).getText()).isEqualTo("Question value");
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolNames())
|
||||
.containsAll(toolNames);
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolCallbacks())
|
||||
.contains(toolCallback);
|
||||
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolContext())
|
||||
.containsAllEntriesOf(toolContext);
|
||||
assertThat((List<Advisor>) chatClientRequest.context().get(ChatClientAttributes.ADVISORS.getKey()))
|
||||
.containsAll(advisors);
|
||||
assertThat(chatClientRequest.context()).containsAllEntriesOf(advisorContext);
|
||||
|
||||
AdvisedRequest convertedAdvisedRequest = AdvisedRequest.from(chatClientRequest);
|
||||
assertThat(convertedAdvisedRequest.toPrompt()).isEqualTo(chatClientRequest.prompt());
|
||||
assertThat(convertedAdvisedRequest.adviseContext()).containsAllEntriesOf(chatClientRequest.context());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -21,6 +21,7 @@ import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClientResponse;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
@@ -67,7 +68,8 @@ class AdvisedResponseTests {
|
||||
|
||||
@Test
|
||||
void whenBuildFromNullAdvisedResponseThenThrows() {
|
||||
assertThatThrownBy(() -> AdvisedResponse.from(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
assertThatThrownBy(() -> AdvisedResponse.from((AdvisedResponse) null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("advisedResponse cannot be null");
|
||||
}
|
||||
|
||||
@@ -85,4 +87,16 @@ class AdvisedResponseTests {
|
||||
.hasMessage("contextTransform cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenConvertToAndFromChatClientResponse() {
|
||||
ChatResponse chatResponse = mock(ChatResponse.class);
|
||||
Map<String, Object> context = Map.of("key", "value");
|
||||
AdvisedResponse advisedResponse = new AdvisedResponse(chatResponse, context);
|
||||
|
||||
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
|
||||
|
||||
AdvisedResponse newAdvisedResponse = AdvisedResponse.from(chatClientResponse);
|
||||
assertThat(newAdvisedResponse).isEqualTo(advisedResponse);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -17,9 +17,13 @@
|
||||
package org.springframework.ai.chat.client.advisor.observation;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link AdvisorObservationContext}.
|
||||
@@ -32,8 +36,7 @@ class AdvisorObservationContextTests {
|
||||
@Test
|
||||
void whenMandatoryOptionsThenReturn() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("MyName")
|
||||
.advisorType(AdvisorObservationContext.Type.BEFORE)
|
||||
.advisorName("AdvisorName")
|
||||
.build();
|
||||
|
||||
assertThat(observationContext).isNotNull();
|
||||
@@ -41,17 +44,38 @@ class AdvisorObservationContextTests {
|
||||
|
||||
@Test
|
||||
void missingAdvisorName() {
|
||||
assertThatThrownBy(
|
||||
() -> AdvisorObservationContext.builder().advisorType(AdvisorObservationContext.Type.BEFORE).build())
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder().build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("advisorName must not be null or empty");
|
||||
.hasMessageContaining("advisorName cannot be null or empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
void missingAdvisorType() {
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder().advisorName("MyName").build())
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("advisorType must not be null");
|
||||
void whenBuilderWithAdvisedRequestThenReturn() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("AdvisorName")
|
||||
.advisedRequest(mock(AdvisedRequest.class))
|
||||
.build();
|
||||
|
||||
assertThat(observationContext).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void whenBuilderWithChatClientRequestThenReturn() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("AdvisorName")
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build())
|
||||
.build();
|
||||
|
||||
assertThat(observationContext).isNotNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void missingBuilderWithBothRequestsThenThrow() {
|
||||
assertThatThrownBy(() -> AdvisorObservationContext.builder()
|
||||
.advisedRequest(mock(AdvisedRequest.class))
|
||||
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build())
|
||||
.build()).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("ChatClientRequest and AdvisedRequest cannot be set at the same time");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -47,7 +47,6 @@ class DefaultAdvisorObservationConventionTests {
|
||||
void contextualName() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("MyName")
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.build();
|
||||
assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("my_name");
|
||||
}
|
||||
@@ -56,7 +55,6 @@ class DefaultAdvisorObservationConventionTests {
|
||||
void supportsAdvisorObservationContext() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("MyName")
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.build();
|
||||
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
|
||||
assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse();
|
||||
@@ -66,11 +64,8 @@ class DefaultAdvisorObservationConventionTests {
|
||||
void shouldHaveLowCardinalityKeyValuesWhenDefined() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("MyName")
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.build();
|
||||
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
|
||||
KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE.asString(),
|
||||
AdvisorObservationContext.Type.AROUND.name()),
|
||||
KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.FRAMEWORK.value()),
|
||||
KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.SPRING_AI.value()),
|
||||
KeyValue.of(LowCardinalityKeyNames.ADVISOR_NAME.asString(), "MyName"),
|
||||
@@ -81,7 +76,6 @@ class DefaultAdvisorObservationConventionTests {
|
||||
void shouldHaveKeyValuesWhenDefinedAndResponse() {
|
||||
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
|
||||
.advisorName("MyName")
|
||||
.advisorType(AdvisorObservationContext.Type.AROUND)
|
||||
.order(678)
|
||||
.build();
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,20 +16,22 @@
|
||||
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@@ -57,14 +59,9 @@ class ChatClientInputContentObservationFilterTests {
|
||||
|
||||
@Test
|
||||
void whenEmptyInputContentThenReturnOriginalContext() {
|
||||
ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
ChatClientObservationConvention customObservationConvention = null;
|
||||
var request = ChatClientRequest.builder().prompt(new Prompt()).build();
|
||||
|
||||
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
|
||||
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention,
|
||||
Map.of());
|
||||
|
||||
var expectedContext = ChatClientObservationContext.builder().withRequest(request).build();
|
||||
var expectedContext = ChatClientObservationContext.builder().request(request).build();
|
||||
|
||||
var actualContext = this.observationFilter.map(expectedContext);
|
||||
|
||||
@@ -73,14 +70,13 @@ class ChatClientInputContentObservationFilterTests {
|
||||
|
||||
@Test
|
||||
void whenWithTextThenAugmentContext() {
|
||||
ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
|
||||
ChatClientObservationConvention customObservationConvention = null;
|
||||
var request = ChatClientRequest.builder()
|
||||
.prompt(new Prompt(new SystemMessage("sample system text"), new UserMessage("sample user text")))
|
||||
.context(ChatClientAttributes.USER_PARAMS.getKey(), Map.of("up1", "upv1"))
|
||||
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), Map.of("sp1", "sp1v"))
|
||||
.build();
|
||||
|
||||
var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"),
|
||||
"sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null,
|
||||
List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of());
|
||||
|
||||
var originalContext = ChatClientObservationContext.builder().withRequest(request).build();
|
||||
var originalContext = ChatClientObservationContext.builder().request(request).build();
|
||||
|
||||
var augmentedContext = this.observationFilter.map(originalContext);
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,17 +16,14 @@
|
||||
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@@ -44,11 +41,10 @@ class ChatClientObservationContextTests {
|
||||
|
||||
@Test
|
||||
void whenMandatoryRequestOptionsThenReturn() {
|
||||
|
||||
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
|
||||
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
|
||||
|
||||
var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build();
|
||||
var observationContext = ChatClientObservationContext.builder()
|
||||
.request(ChatClientRequest.builder().prompt(new Prompt()).build())
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
assertThat(observationContext).isNotNull();
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -17,18 +17,17 @@
|
||||
package org.springframework.ai.chat.client.observation;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import io.micrometer.common.KeyValue;
|
||||
import io.micrometer.observation.Observation;
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
|
||||
import org.springframework.ai.chat.client.ChatClientAttributes;
|
||||
import org.springframework.ai.chat.client.ChatClientRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
|
||||
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
|
||||
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
|
||||
@@ -36,7 +35,9 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.model.function.FunctionCallback;
|
||||
import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
||||
import org.springframework.ai.observation.conventions.AiProvider;
|
||||
import org.springframework.ai.observation.conventions.SpringAiKind;
|
||||
|
||||
@@ -56,7 +57,7 @@ class DefaultChatClientObservationConventionTests {
|
||||
@Mock
|
||||
ChatModel chatModel;
|
||||
|
||||
DefaultChatClientRequestSpec request;
|
||||
ChatClientRequest request;
|
||||
|
||||
static CallAroundAdvisor dummyAdvisor(String name) {
|
||||
return new CallAroundAdvisor() {
|
||||
@@ -109,8 +110,7 @@ class DefaultChatClientObservationConventionTests {
|
||||
|
||||
@BeforeEach
|
||||
public void beforeEach() {
|
||||
this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(),
|
||||
List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
|
||||
this.request = ChatClientRequest.builder().prompt(new Prompt()).build();
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -121,8 +121,8 @@ class DefaultChatClientObservationConventionTests {
|
||||
@Test
|
||||
void shouldHaveContextualName() {
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(this.request)
|
||||
.withStream(true)
|
||||
.request(this.request)
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
assertThat(this.observationConvention.getContextualName(observationContext))
|
||||
@@ -132,8 +132,8 @@ class DefaultChatClientObservationConventionTests {
|
||||
@Test
|
||||
void supportsOnlyChatClientObservationContext() {
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(this.request)
|
||||
.withStream(true)
|
||||
.request(this.request)
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
|
||||
@@ -143,8 +143,8 @@ class DefaultChatClientObservationConventionTests {
|
||||
@Test
|
||||
void shouldHaveRequiredKeyValues() {
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(this.request)
|
||||
.withStream(true)
|
||||
.request(this.request)
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
|
||||
@@ -154,27 +154,31 @@ class DefaultChatClientObservationConventionTests {
|
||||
|
||||
@Test
|
||||
void shouldHaveOptionalKeyValues() {
|
||||
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(),
|
||||
List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(),
|
||||
List.of("function1", "function2"), List.of(), null,
|
||||
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"),
|
||||
ObservationRegistry.NOOP, null, Map.of());
|
||||
var request = ChatClientRequest.builder()
|
||||
.prompt(new Prompt("",
|
||||
ToolCallingChatOptions.builder()
|
||||
.toolNames("tool1", "tool2")
|
||||
.toolCallbacks(dummyFunction("toolCallback1"), dummyFunction("toolCallback2"))
|
||||
.build()))
|
||||
.context("advParam1", "advisorParam1Value")
|
||||
.context(ChatClientAttributes.ADVISORS.getKey(),
|
||||
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")))
|
||||
.build();
|
||||
|
||||
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
|
||||
.withRequest(request)
|
||||
.request(request)
|
||||
.withFormat("json")
|
||||
.withStream(true)
|
||||
.stream(true)
|
||||
.build();
|
||||
|
||||
assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains(
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(),
|
||||
"[\"advisor1\", \"advisor2\", \"CallAroundAdvisor\", \"StreamAroundAdvisor\"]"),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), "[\"advisor1\", \"advisor2\"]"),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(),
|
||||
"[\"advParam1\":\"advisorParam1Value\"]"),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(),
|
||||
"[\"function1\", \"function2\"]"),
|
||||
"[\"tool1\", \"tool2\"]"),
|
||||
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(),
|
||||
"[\"functionCallback1\", \"functionCallback2\"]"));
|
||||
"[\"toolCallback1\", \"toolCallback2\"]"));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -148,4 +148,21 @@ class PromptTests {
|
||||
assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void mutatePrompt() {
|
||||
String template = "Hello, {name}! Your age is {age}.";
|
||||
Map<String, Object> model = new HashMap<>();
|
||||
model.put("name", "Alice");
|
||||
model.put("age", 30);
|
||||
PromptTemplate promptTemplate = new PromptTemplate(template, model);
|
||||
ChatOptions chatOptions = ChatOptions.builder().temperature(0.5).maxTokens(100).build();
|
||||
|
||||
Prompt prompt = promptTemplate.create(model, chatOptions);
|
||||
|
||||
Prompt copiedPrompt = prompt.mutate().build();
|
||||
assertThat(prompt).isNotSameAs(copiedPrompt);
|
||||
assertThat(prompt.getOptions()).isNotSameAs(copiedPrompt.getOptions());
|
||||
assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
@@ -49,6 +50,7 @@ public abstract class AbstractMessage implements Message {
|
||||
/**
|
||||
* The content of the message.
|
||||
*/
|
||||
@Nullable
|
||||
protected final String textContent;
|
||||
|
||||
/**
|
||||
@@ -63,11 +65,12 @@ public abstract class AbstractMessage implements Message {
|
||||
* @param textContent the text content
|
||||
* @param metadata the metadata
|
||||
*/
|
||||
protected AbstractMessage(MessageType messageType, String textContent, Map<String, Object> metadata) {
|
||||
protected AbstractMessage(MessageType messageType, @Nullable String textContent, Map<String, Object> metadata) {
|
||||
Assert.notNull(messageType, "Message type must not be null");
|
||||
if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) {
|
||||
Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages");
|
||||
}
|
||||
Assert.notNull(metadata, "Metadata must not be null");
|
||||
this.messageType = messageType;
|
||||
this.textContent = textContent;
|
||||
this.metadata = new HashMap<>(metadata);
|
||||
@@ -81,7 +84,9 @@ public abstract class AbstractMessage implements Message {
|
||||
* @param metadata the metadata
|
||||
*/
|
||||
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> metadata) {
|
||||
Assert.notNull(messageType, "Message type must not be null");
|
||||
Assert.notNull(resource, "Resource must not be null");
|
||||
Assert.notNull(metadata, "Metadata must not be null");
|
||||
try (InputStream inputStream = resource.getInputStream()) {
|
||||
this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||
}
|
||||
@@ -98,6 +103,7 @@ public abstract class AbstractMessage implements Message {
|
||||
* @return the content of the message
|
||||
*/
|
||||
@Override
|
||||
@Nullable
|
||||
public String getText() {
|
||||
return this.textContent;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.Charset;
|
||||
|
||||
/**
|
||||
* Utility class for managing messages.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
final class MessageUtils {
|
||||
|
||||
private MessageUtils() {
|
||||
}
|
||||
|
||||
static String readResource(Resource resource) {
|
||||
return readResource(resource, Charset.defaultCharset());
|
||||
}
|
||||
|
||||
static String readResource(Resource resource, Charset charset) {
|
||||
Assert.notNull(resource, "resource cannot be null");
|
||||
Assert.notNull(charset, "charset cannot be null");
|
||||
try (InputStream inputStream = resource.getInputStream()) {
|
||||
return StreamUtils.copyToString(inputStream, charset);
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new RuntimeException("Failed to read resource", ex);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,10 +16,14 @@
|
||||
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* A message of the type 'system' passed as input. The system message gives high level
|
||||
@@ -31,14 +35,19 @@ import org.springframework.core.io.Resource;
|
||||
public class SystemMessage extends AbstractMessage {
|
||||
|
||||
public SystemMessage(String textContent) {
|
||||
super(MessageType.SYSTEM, textContent, Map.of());
|
||||
this(textContent, Map.of());
|
||||
}
|
||||
|
||||
public SystemMessage(Resource resource) {
|
||||
super(MessageType.SYSTEM, resource, Map.of());
|
||||
this(MessageUtils.readResource(resource), Map.of());
|
||||
}
|
||||
|
||||
private SystemMessage(String textContent, Map<String, Object> metadata) {
|
||||
super(MessageType.SYSTEM, textContent, metadata);
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public String getText() {
|
||||
return this.textContent;
|
||||
}
|
||||
@@ -68,4 +77,53 @@ public class SystemMessage extends AbstractMessage {
|
||||
+ ", metadata=" + this.metadata + '}';
|
||||
}
|
||||
|
||||
public SystemMessage copy() {
|
||||
return new SystemMessage(getText(), Map.copyOf(this.metadata));
|
||||
}
|
||||
|
||||
public Builder mutate() {
|
||||
return new Builder().text(this.textContent).metadata(this.metadata);
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
@Nullable
|
||||
private String textContent;
|
||||
|
||||
@Nullable
|
||||
private Resource resource;
|
||||
|
||||
private Map<String, Object> metadata = new HashMap<>();
|
||||
|
||||
public Builder text(String textContent) {
|
||||
this.textContent = textContent;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder text(Resource resource) {
|
||||
this.resource = resource;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder metadata(Map<String, Object> metadata) {
|
||||
this.metadata = metadata;
|
||||
return this;
|
||||
}
|
||||
|
||||
public SystemMessage build() {
|
||||
if (StringUtils.hasText(textContent) && resource != null) {
|
||||
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
|
||||
}
|
||||
else if (resource != null) {
|
||||
this.textContent = MessageUtils.readResource(resource);
|
||||
}
|
||||
return new SystemMessage(this.textContent, this.metadata);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -19,13 +19,17 @@ package org.springframework.ai.chat.messages;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.ai.content.MediaContent;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.lang.NonNull;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* A message of the type 'user' passed as input Messages with the user role are from the
|
||||
@@ -37,26 +41,45 @@ public class UserMessage extends AbstractMessage implements MediaContent {
|
||||
protected final List<Media> media;
|
||||
|
||||
public UserMessage(String textContent) {
|
||||
this(MessageType.USER, textContent, new ArrayList<>(), Map.of());
|
||||
this(textContent, new ArrayList<>(), Map.of());
|
||||
}
|
||||
|
||||
public UserMessage(Resource resource) {
|
||||
super(MessageType.USER, resource, Map.of());
|
||||
this.media = new ArrayList<>();
|
||||
this(MessageUtils.readResource(resource));
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #builder()} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public UserMessage(String textContent, List<Media> media) {
|
||||
this(MessageType.USER, textContent, media, Map.of());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #builder()} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public UserMessage(String textContent, Media... media) {
|
||||
this(textContent, Arrays.asList(media));
|
||||
}
|
||||
|
||||
public UserMessage(String textContent, Collection<Media> mediaList, Map<String, Object> metadata) {
|
||||
this(MessageType.USER, textContent, mediaList, metadata);
|
||||
/**
|
||||
* @deprecated use {@link #builder()} instead. Will be made private in the next
|
||||
* release.
|
||||
*/
|
||||
@Deprecated
|
||||
public UserMessage(String textContent, Collection<Media> media, Map<String, Object> metadata) {
|
||||
super(MessageType.USER, textContent, metadata);
|
||||
Assert.notNull(media, "media cannot be null");
|
||||
Assert.noNullElements(media, "media cannot have null elements");
|
||||
this.media = new ArrayList<>(media);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated use {@link #builder()} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public UserMessage(MessageType messageType, String textContent, Collection<Media> media,
|
||||
Map<String, Object> metadata) {
|
||||
super(messageType, textContent, metadata);
|
||||
@@ -71,13 +94,77 @@ public class UserMessage extends AbstractMessage implements MediaContent {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Media> getMedia() {
|
||||
return this.media;
|
||||
}
|
||||
|
||||
@Override
|
||||
@NonNull
|
||||
public String getText() {
|
||||
return this.textContent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Media> getMedia() {
|
||||
return this.media;
|
||||
}
|
||||
|
||||
public UserMessage copy() {
|
||||
return new UserMessage(getText(), List.copyOf(getMedia()), Map.copyOf(getMetadata()));
|
||||
}
|
||||
|
||||
public Builder mutate() {
|
||||
return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata()));
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
@Nullable
|
||||
private String textContent;
|
||||
|
||||
@Nullable
|
||||
private Resource resource;
|
||||
|
||||
private List<Media> media = new ArrayList<>();
|
||||
|
||||
private Map<String, Object> metadata = new HashMap<>();
|
||||
|
||||
public Builder text(String textContent) {
|
||||
this.textContent = textContent;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder text(Resource resource) {
|
||||
this.resource = resource;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder media(List<Media> media) {
|
||||
this.media = media;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder media(@Nullable Media... media) {
|
||||
if (media != null) {
|
||||
this.media = Arrays.asList(media);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder metadata(Map<String, Object> metadata) {
|
||||
this.metadata = metadata;
|
||||
return this;
|
||||
}
|
||||
|
||||
public UserMessage build() {
|
||||
if (StringUtils.hasText(textContent) && resource != null) {
|
||||
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
|
||||
}
|
||||
else if (resource != null) {
|
||||
this.textContent = MessageUtils.readResource(resource);
|
||||
}
|
||||
return new UserMessage(this.textContent, this.media, this.metadata);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
@NonNullApi
|
||||
@NonNullFields
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import org.springframework.lang.NonNullApi;
|
||||
import org.springframework.lang.NonNullFields;
|
||||
@@ -30,6 +30,8 @@ import org.springframework.ai.chat.messages.ToolResponseMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.model.ModelRequest;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* The Prompt class represents a prompt used in AI model requests. A prompt consists of
|
||||
@@ -62,15 +64,15 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
this(Arrays.asList(messages), null);
|
||||
}
|
||||
|
||||
public Prompt(String contents, ChatOptions chatOptions) {
|
||||
public Prompt(String contents, @Nullable ChatOptions chatOptions) {
|
||||
this(new UserMessage(contents), chatOptions);
|
||||
}
|
||||
|
||||
public Prompt(Message message, ChatOptions chatOptions) {
|
||||
public Prompt(Message message, @Nullable ChatOptions chatOptions) {
|
||||
this(Collections.singletonList(message), chatOptions);
|
||||
}
|
||||
|
||||
public Prompt(List<Message> messages, ChatOptions chatOptions) {
|
||||
public Prompt(List<Message> messages, @Nullable ChatOptions chatOptions) {
|
||||
this.messages = messages;
|
||||
this.chatOptions = chatOptions;
|
||||
}
|
||||
@@ -123,10 +125,17 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
List<Message> messagesCopy = new ArrayList<>();
|
||||
this.messages.forEach(message -> {
|
||||
if (message instanceof UserMessage userMessage) {
|
||||
messagesCopy.add(new UserMessage(userMessage.getText(), userMessage.getMedia(), message.getMetadata()));
|
||||
messagesCopy.add(UserMessage.builder()
|
||||
.text(userMessage.getText())
|
||||
.media(userMessage.getMedia())
|
||||
.metadata(message.getMetadata())
|
||||
.build());
|
||||
}
|
||||
else if (message instanceof SystemMessage systemMessage) {
|
||||
messagesCopy.add(new SystemMessage(systemMessage.getText()));
|
||||
messagesCopy.add(SystemMessage.builder()
|
||||
.text(systemMessage.getText())
|
||||
.metadata(systemMessage.getMetadata())
|
||||
.build());
|
||||
}
|
||||
else if (message instanceof AssistantMessage assistantMessage) {
|
||||
messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
|
||||
@@ -144,4 +153,69 @@ public class Prompt implements ModelRequest<List<Message>> {
|
||||
return messagesCopy;
|
||||
}
|
||||
|
||||
public Builder mutate() {
|
||||
Builder builder = new Builder().messages(instructionsCopy());
|
||||
if (this.chatOptions != null) {
|
||||
builder.chatOptions(this.chatOptions.copy());
|
||||
}
|
||||
return builder;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
@Nullable
|
||||
private String content;
|
||||
|
||||
@Nullable
|
||||
private List<Message> messages = new ArrayList<>();
|
||||
|
||||
@Nullable
|
||||
private ChatOptions chatOptions;
|
||||
|
||||
public Builder content(@Nullable String content) {
|
||||
this.content = content;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder messages(Message... messages) {
|
||||
if (messages != null) {
|
||||
this.messages = Arrays.asList(messages);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder messages(List<Message> messages) {
|
||||
this.messages = messages;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder addMessage(Message message) {
|
||||
if (this.messages == null) {
|
||||
this.messages = new ArrayList<>();
|
||||
}
|
||||
this.messages.add(message);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder chatOptions(ChatOptions chatOptions) {
|
||||
this.chatOptions = chatOptions;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Prompt build() {
|
||||
if (StringUtils.hasText(this.content) && !CollectionUtils.isEmpty(this.messages)) {
|
||||
throw new IllegalArgumentException("content and messages cannot be set at the same time");
|
||||
}
|
||||
else if (StringUtils.hasText(this.content)) {
|
||||
this.messages = List.of(new UserMessage(this.content));
|
||||
}
|
||||
return new Prompt(this.messages, this.chatOptions);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link MessageUtils}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class MessageUtilsTests {
|
||||
|
||||
@Test
|
||||
void readResource() {
|
||||
String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt"));
|
||||
assertThat(content).isEqualTo("Hello, world!");
|
||||
}
|
||||
|
||||
@Test
|
||||
void readResourceWhenNull() {
|
||||
assertThatThrownBy(() -> MessageUtils.readResource(null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("resource cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
void readResourceWithCharset() {
|
||||
String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), StandardCharsets.UTF_8);
|
||||
assertThat(content).isEqualTo("Hello, world!");
|
||||
}
|
||||
|
||||
@Test
|
||||
void readResourceWithCharsetWhenNull() {
|
||||
assertThatThrownBy(() -> MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("charset cannot be null");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link SystemMessage}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class SystemMessageTests {
|
||||
|
||||
@Test
|
||||
void systemMessageWithNullText() {
|
||||
assertThrows(IllegalArgumentException.class, () -> new SystemMessage((String) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageWithTextContent() {
|
||||
String text = "Tell me, did you sail across the sun?";
|
||||
SystemMessage message = new SystemMessage(text);
|
||||
assertEquals(text, message.getText());
|
||||
assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE));
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageWithNullResource() {
|
||||
assertThrows(IllegalArgumentException.class, () -> new SystemMessage((Resource) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageWithResource() {
|
||||
SystemMessage message = new SystemMessage(new ClassPathResource("prompt-system.txt"));
|
||||
assertEquals("Tell me, did you sail across the sun?", message.getText());
|
||||
assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE));
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageFromBuilderWithText() {
|
||||
String text = "Tell me, did you sail across the sun?";
|
||||
SystemMessage message = SystemMessage.builder().text(text).metadata(Map.of("key", "value")).build();
|
||||
assertEquals(text, message.getText());
|
||||
assertThat(message.getMetadata()).hasSize(2)
|
||||
.containsEntry(MESSAGE_TYPE, MessageType.SYSTEM)
|
||||
.containsEntry("key", "value");
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageFromBuilderWithResource() {
|
||||
Resource resource = new ClassPathResource("prompt-system.txt");
|
||||
SystemMessage message = SystemMessage.builder().text(resource).metadata(Map.of("key", "value")).build();
|
||||
assertEquals("Tell me, did you sail across the sun?", message.getText());
|
||||
assertThat(message.getMetadata()).hasSize(2)
|
||||
.containsEntry(MESSAGE_TYPE, MessageType.SYSTEM)
|
||||
.containsEntry("key", "value");
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageCopy() {
|
||||
String text1 = "Tell me, did you sail across the sun?";
|
||||
Map<String, Object> metadata1 = Map.of("key", "value");
|
||||
SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build();
|
||||
|
||||
SystemMessage systemMessage2 = systemMessage1.copy();
|
||||
|
||||
assertThat(systemMessage2.getText()).isEqualTo(text1);
|
||||
assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void systemMessageMutate() {
|
||||
String text1 = "Tell me, did you sail across the sun?";
|
||||
Map<String, Object> metadata1 = Map.of("key", "value");
|
||||
SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build();
|
||||
|
||||
SystemMessage systemMessage2 = systemMessage1.mutate().build();
|
||||
|
||||
assertThat(systemMessage2.getText()).isEqualTo(text1);
|
||||
assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
|
||||
|
||||
String text3 = "Farewell, Aragog!";
|
||||
SystemMessage systemMessage3 = systemMessage2.mutate().text(text3).build();
|
||||
|
||||
assertThat(systemMessage3.getText()).isEqualTo(text3);
|
||||
assertThat(systemMessage3.getMetadata()).hasSize(2).isNotSameAs(systemMessage2.getMetadata());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
* Copyright 2023-2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.messages;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.content.Media;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link UserMessage}.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
*/
|
||||
class UserMessageTests {
|
||||
|
||||
@Test
|
||||
void userMessageWithNullText() {
|
||||
assertThatThrownBy(() -> new UserMessage((String) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("Content must not be null for SYSTEM or USER messages");
|
||||
;
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageWithTextContent() {
|
||||
String text = "Hello, world!";
|
||||
UserMessage message = new UserMessage(text);
|
||||
assertThat(message.getText()).isEqualTo(text);
|
||||
assertThat(message.getMedia()).isEmpty();
|
||||
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageWithNullResource() {
|
||||
assertThatThrownBy(() -> new UserMessage((Resource) null)).isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("resource cannot be null");
|
||||
;
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageWithResource() {
|
||||
UserMessage message = new UserMessage(new ClassPathResource("prompt-user.txt"));
|
||||
assertThat(message.getText()).isEqualTo("Hello, world!");
|
||||
assertThat(message.getMedia()).isEmpty();
|
||||
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageFromBuilderWithText() {
|
||||
String text = "Hello, world!";
|
||||
UserMessage message = UserMessage.builder()
|
||||
.text(text)
|
||||
.media(new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt")))
|
||||
.metadata(Map.of("key", "value"))
|
||||
.build();
|
||||
assertThat(message.getText()).isEqualTo(text);
|
||||
assertThat(message.getMedia()).hasSize(1);
|
||||
assertThat(message.getMetadata()).hasSize(2)
|
||||
.containsEntry(MESSAGE_TYPE, MessageType.USER)
|
||||
.containsEntry("key", "value");
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageFromBuilderWithResource() {
|
||||
UserMessage message = UserMessage.builder().text(new ClassPathResource("prompt-user.txt")).build();
|
||||
assertThat(message.getText()).isEqualTo("Hello, world!");
|
||||
assertThat(message.getMedia()).isEmpty();
|
||||
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageCopy() {
|
||||
String text1 = "Hello, world!";
|
||||
Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt"));
|
||||
Map<String, Object> metadata1 = Map.of("key", "value");
|
||||
UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build();
|
||||
|
||||
UserMessage userMessage2 = userMessage1.copy();
|
||||
|
||||
assertThat(userMessage2.getText()).isEqualTo(text1);
|
||||
assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1);
|
||||
assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void userMessageMutate() {
|
||||
String text1 = "Hello, world!";
|
||||
Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt"));
|
||||
Map<String, Object> metadata1 = Map.of("key", "value");
|
||||
UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build();
|
||||
|
||||
UserMessage userMessage2 = userMessage1.mutate().build();
|
||||
|
||||
assertThat(userMessage2.getText()).isEqualTo(text1);
|
||||
assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1);
|
||||
assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
|
||||
|
||||
String text3 = "Farewell, Aragog!";
|
||||
UserMessage userMessage3 = userMessage2.mutate().text(text3).build();
|
||||
|
||||
assertThat(userMessage3.getText()).isEqualTo(text3);
|
||||
assertThat(userMessage3.getMedia()).hasSize(1).isNotSameAs(metadata1);
|
||||
assertThat(userMessage3.getMetadata()).hasSize(2).isNotSameAs(metadata1);
|
||||
}
|
||||
|
||||
}
|
||||
1
spring-ai-model/src/test/resources/prompt-system.txt
Normal file
1
spring-ai-model/src/test/resources/prompt-system.txt
Normal file
@@ -0,0 +1 @@
|
||||
Tell me, did you sail across the sun?
|
||||
1
spring-ai-model/src/test/resources/prompt-user.txt
Normal file
1
spring-ai-model/src/test/resources/prompt-user.txt
Normal file
@@ -0,0 +1 @@
|
||||
Hello, world!
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2025 - 2025 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.ai.test;
|
||||
|
||||
/**
|
||||
* Utility class for escaping curly brackets in strings
|
||||
*
|
||||
* @author Christian Tzolov
|
||||
*
|
||||
*/
|
||||
public class CurlyBracketEscaper {
|
||||
|
||||
/**
|
||||
* Escapes all curly brackets in the input string by adding a backslash before them
|
||||
* @param input The string containing curly brackets to escape
|
||||
* @return The string with escaped curly brackets
|
||||
*/
|
||||
public static String escapeCurlyBrackets(String input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
return input.replace("{", "\\{").replace("}", "\\}");
|
||||
}
|
||||
|
||||
/**
|
||||
* Unescapes previously escaped curly brackets by removing the backslashes
|
||||
* @param input The string containing escaped curly brackets
|
||||
* @return The string with unescaped curly brackets
|
||||
*/
|
||||
public static String unescapeCurlyBrackets(String input) {
|
||||
if (input == null) {
|
||||
return null;
|
||||
}
|
||||
return input.replace("\\{", "{").replace("\\}", "}");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user