Make ChatClient and Advisor APIs more robust - Part 1

- Introduce “ChatClientRequest” and “ChatClientResponse” for propagating requests/responses in a ChatClient advisor chain.
- Structure a Prompt at the beginning of the chain, to ensure a consistent view across execution chain and observations. Any template is rendered at the beginning so that every advisor doesn’t have to do it again.
- Improve observations to include the complete view of the prompt messages, instead of only considering userText and systemText.
- Remove legacy “around” advisor type concept.
- Keep backward compatibility for AdvisedRequest, AdvisedResponse, and legacy Advisor APIs.

Relates to gh-2655

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
This commit is contained in:
Thomas Vitale
2025-04-15 20:26:00 +02:00
committed by Christian Tzolov
parent 593083980b
commit 1f59ccadad
53 changed files with 2365 additions and 487 deletions

View File

@@ -42,6 +42,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.test.CurlyBracketEscaper;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -189,7 +190,7 @@ class AnthropicChatClientIT {
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
.stream()
.content();

View File

@@ -31,6 +31,7 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.test.CurlyBracketEscaper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
@@ -83,7 +84,7 @@ public class AzureOpenAiChatClientIT {
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
.stream()
.chatResponse();

View File

@@ -37,6 +37,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.test.CurlyBracketEscaper;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -182,7 +183,7 @@ class BedrockConverseChatClientIT {
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
.stream()
.chatResponse();

View File

@@ -28,12 +28,14 @@ import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice;
import org.springframework.ai.test.CurlyBracketEscaper;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -198,10 +200,11 @@ class MistralAiChatClientIT {
// @formatter:off
Flux<String> chatResponse = ChatClient.create(this.chatModel)
.prompt()
.advisors(new SimpleLoggerAdvisor())
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
.stream()
.content();

View File

@@ -43,6 +43,7 @@ import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.ai.test.CurlyBracketEscaper;
import org.springframework.ai.tool.function.FunctionToolCallback;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
@@ -220,7 +221,7 @@ class OpenAiChatClientIT extends AbstractIT {
.user(u -> u
.text("Generate the filmography of 5 movies for Tom Hanks. " + System.lineSeparator()
+ "{format}")
.param("format", outputConverter.getFormat()))
.param("format", CurlyBracketEscaper.escapeCurlyBrackets(outputConverter.getFormat())))
.stream()
.chatResponse();

View File

@@ -156,6 +156,8 @@ public interface ChatClient {
@Nullable
<T> T entity(Class<T> type);
ChatClientResponse chatClientResponse();
@Nullable
ChatResponse chatResponse();
@@ -172,6 +174,8 @@ public interface ChatClient {
interface StreamResponseSpec {
Flux<ChatClientResponse> chatClientResponse();
Flux<ChatResponse> chatResponse();
Flux<String> content();

View File

@@ -0,0 +1,52 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
/**
* Common attributes used in {@link ChatClient} context.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public enum ChatClientAttributes {
//@formatter:off
@Deprecated // Only for backward compatibility until the next release.
ADVISORS("spring.ai.chat.client.advisors"),
@Deprecated // Only for backward compatibility until the next release.
CHAT_MODEL("spring.ai.chat.client.model"),
@Deprecated // Only for backward compatibility until the next release.
OUTPUT_FORMAT("spring.ai.chat.client.output.format"),
@Deprecated // Only for backward compatibility until the next release.
USER_PARAMS("spring.ai.chat.client.user.params"),
@Deprecated // Only for backward compatibility until the next release.
SYSTEM_PARAMS("spring.ai.chat.client.system.params");
//@formatter:on
private final String key;
ChatClientAttributes(String key) {
this.key = key;
}
public String getKey() {
return key;
}
}

View File

@@ -0,0 +1,83 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.util.Assert;
import java.util.HashMap;
import java.util.Map;
/**
* Represents a request processed by a {@link ChatClient} that ultimately is used to build
* a {@link Prompt} to be sent to an AI model.
*
* @param prompt The prompt to be sent to the AI model
* @param context The contextual data through the execution chain
* @author Thomas Vitale
* @since 1.0.0
*/
public record ChatClientRequest(Prompt prompt, Map<String, Object> context) {
public ChatClientRequest {
Assert.notNull(prompt, "prompt cannot be null");
Assert.notNull(context, "context cannot be null");
Assert.noNullElements(context.keySet(), "context keys cannot be null");
}
public Builder mutate() {
return new Builder().prompt(this.prompt).context(this.context);
}
public static Builder builder() {
return new Builder();
}
public static final class Builder {
private Prompt prompt;
private Map<String, Object> context = new HashMap<>();
private Builder() {
}
public Builder prompt(Prompt prompt) {
Assert.notNull(prompt, "prompt cannot be null");
this.prompt = prompt;
return this;
}
public Builder context(Map<String, Object> context) {
Assert.notNull(context, "context cannot be null");
this.context.putAll(context);
return this;
}
public Builder context(String key, Object value) {
Assert.notNull(key, "key cannot be null");
this.context.put(key, value);
return this;
}
public ChatClientRequest build() {
return new ChatClientRequest(prompt, context);
}
}
}

View File

@@ -0,0 +1,77 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.HashMap;
import java.util.Map;
/**
* Represents a response returned by a {@link ChatClient}.
*
* @param chatResponse The response returned by the AI model
* @param context The contextual data propagated through the execution chain
* @author Thomas Vitale
* @since 1.0.0
*/
public record ChatClientResponse(@Nullable ChatResponse chatResponse, Map<String, Object> context) {
public ChatClientResponse {
Assert.notNull(context, "context cannot be null");
Assert.noNullElements(context.keySet(), "context keys cannot be null");
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private ChatResponse chatResponse;
private Map<String, Object> context = new HashMap<>();
private Builder() {
}
public Builder chatResponse(ChatResponse chatResponse) {
this.chatResponse = chatResponse;
return this;
}
public Builder context(Map<String, Object> context) {
Assert.notNull(context, "context cannot be null");
this.context.putAll(context);
return this;
}
public Builder context(String key, Object value) {
Assert.notNull(key, "key cannot be null");
this.context.put(key, value);
return this;
}
public ChatClientResponse build() {
return new ChatClientResponse(this.chatResponse, this.context);
}
}
}

View File

@@ -22,7 +22,6 @@ import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -33,17 +32,19 @@ import java.util.function.Consumer;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.springframework.ai.chat.client.advisor.ChatModelCallAdvisor;
import org.springframework.ai.chat.client.advisor.ChatModelStreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.lang.NonNull;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
@@ -62,10 +63,6 @@ import org.springframework.ai.content.Media;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.ToolCallbackProvider;
import org.springframework.ai.tool.ToolCallbacks;
import org.springframework.core.Ordered;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
@@ -98,14 +95,10 @@ public class DefaultChatClient implements ChatClient {
this.defaultChatClientRequest = defaultChatClientRequest;
}
private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest,
@Nullable String formatParam) {
private static AdvisedRequest toAdvisedRequest(DefaultChatClientRequestSpec inputRequest) {
Assert.notNull(inputRequest, "inputRequest cannot be null");
Map<String, Object> advisorContext = new ConcurrentHashMap<>(inputRequest.getAdvisorParams());
if (StringUtils.hasText(formatParam)) {
advisorContext.put("formatParam", formatParam);
}
// Process userText, media and messages before creating the AdvisedRequest.
String userText = inputRequest.userText;
@@ -131,11 +124,12 @@ public class DefaultChatClient implements ChatClient {
}
return new AdvisedRequest(inputRequest.chatModel, userText, inputRequest.systemText, inputRequest.chatOptions,
media, inputRequest.functionNames, inputRequest.functionCallbacks, messages, inputRequest.userParams,
media, inputRequest.toolNames, inputRequest.toolCallbacks, messages, inputRequest.userParams,
inputRequest.systemParams, inputRequest.advisors, inputRequest.advisorParams, advisorContext,
inputRequest.toolContext);
}
@Deprecated
public static DefaultChatClientRequestSpec toDefaultChatClientRequestSpec(AdvisedRequest advisedRequest,
ObservationRegistry observationRegistry, ChatClientObservationConvention customObservationConvention) {
@@ -391,11 +385,25 @@ public class DefaultChatClient implements ChatClient {
public static class DefaultCallResponseSpec implements CallResponseSpec {
private final DefaultChatClientRequestSpec request;
private final ChatClientRequest request;
public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) {
Assert.notNull(request, "request cannot be null");
this.request = request;
private final BaseAdvisorChain advisorChain;
private final ObservationRegistry observationRegistry;
private final ChatClientObservationConvention observationConvention;
public DefaultCallResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
Assert.notNull(advisorChain, "advisorChain cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.request = chatClientRequest;
this.advisorChain = advisorChain;
this.observationRegistry = observationRegistry;
this.observationConvention = observationConvention;
}
@Override
@@ -419,7 +427,8 @@ public class DefaultChatClient implements ChatClient {
protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> outputConverter) {
Assert.notNull(outputConverter, "structuredOutputConverter cannot be null");
var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat());
var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
.chatResponse();
var responseContent = getContentFromChatResponse(chatResponse);
if (responseContent == null) {
return new ResponseEntity<>(chatResponse, null);
@@ -452,7 +461,8 @@ public class DefaultChatClient implements ChatClient {
@Nullable
private <T> T doSingleWithBeanOutputConverter(StructuredOutputConverter<T> outputConverter) {
var chatResponse = doGetObservableChatResponse(this.request, outputConverter.getFormat());
var chatResponse = doGetObservableChatClientResponse(this.request, outputConverter.getFormat())
.chatResponse();
var stringResponse = getContentFromChatResponse(chatResponse);
if (stringResponse == null) {
return null;
@@ -460,38 +470,85 @@ public class DefaultChatClient implements ChatClient {
return outputConverter.convert(stringResponse);
}
@Nullable
private ChatResponse doGetChatResponse() {
return this.doGetObservableChatResponse(this.request, null);
@Override
public ChatClientResponse chatClientResponse() {
return doGetObservableChatClientResponse(this.request);
}
@Override
@Nullable
private ChatResponse doGetObservableChatResponse(DefaultChatClientRequestSpec inputRequest,
@Nullable String formatParam) {
public ChatResponse chatResponse() {
return doGetObservableChatClientResponse(this.request).chatResponse();
}
@Override
@Nullable
public String content() {
ChatResponse chatResponse = doGetObservableChatClientResponse(this.request).chatResponse();
return getContentFromChatResponse(chatResponse);
}
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest) {
return doGetObservableChatClientResponse(chatClientRequest, null);
}
private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest chatClientRequest,
@Nullable String outputFormat) {
ChatClientRequest formattedChatClientRequest = StringUtils.hasText(outputFormat)
? addFormatInstructionsToPrompt(chatClientRequest, outputFormat) : chatClientRequest;
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(inputRequest)
.withFormat(formatParam)
.withStream(false)
.request(formattedChatClientRequest)
.stream(false)
.withFormat(outputFormat)
.build();
var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
() -> observationContext, inputRequest.getObservationRegistry());
return observation.observe(() -> doGetChatResponse(inputRequest, formatParam, observation));
var observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(observationConvention,
DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext, observationRegistry);
var chatClientResponse = observation.observe(() -> {
// Apply the advisor chain that terminates with the ChatModelCallAdvisor.
return advisorChain.nextCall(formattedChatClientRequest);
});
return chatClientResponse != null ? chatClientResponse : ChatClientResponse.builder().build();
}
private ChatResponse doGetChatResponse(DefaultChatClientRequestSpec inputRequestSpec,
@Nullable String formatParam, Observation parentObservation) {
@NonNull
private static ChatClientRequest addFormatInstructionsToPrompt(ChatClientRequest chatClientRequest,
String outputFormat) {
List<Message> originalMessages = chatClientRequest.prompt().getInstructions();
AdvisedRequest advisedRequest = toAdvisedRequest(inputRequestSpec, formatParam);
if (CollectionUtils.isEmpty(originalMessages)) {
return chatClientRequest;
}
// Apply the around advisor chain that terminates with the last model call
// advisor.
AdvisedResponse advisedResponse = inputRequestSpec.aroundAdvisorChainBuilder.build()
.nextAroundCall(advisedRequest);
// Create a copy of the message list to avoid modifying the original.
List<Message> modifiedMessages = new ArrayList<>(originalMessages);
return advisedResponse.response();
// Get the last message (without removing it from original list)
Message lastMessage = modifiedMessages.get(modifiedMessages.size() - 1);
// If the last message is a UserMessage, replace it with the modified version
if (lastMessage instanceof UserMessage userMessage) {
// Remove last message
modifiedMessages.remove(modifiedMessages.size() - 1);
// Create new user message with format instructions
UserMessage userMessageWithFormat = userMessage.mutate()
.text(userMessage.getText() + System.lineSeparator() + outputFormat)
.build();
// Add modified message back
modifiedMessages.add(userMessageWithFormat);
// Build new ChatClientRequest preserving all properties but with modified
// prompt
return ChatClientRequest.builder()
.prompt(chatClientRequest.prompt().mutate().messages(modifiedMessages).build())
.context(Map.copyOf(chatClientRequest.context()))
.build();
}
return chatClientRequest;
}
@Nullable
@@ -503,53 +560,49 @@ public class DefaultChatClient implements ChatClient {
.orElse(null);
}
@Override
@Nullable
public ChatResponse chatResponse() {
return doGetChatResponse();
}
@Override
@Nullable
public String content() {
ChatResponse chatResponse = doGetChatResponse();
return getContentFromChatResponse(chatResponse);
}
}
public static class DefaultStreamResponseSpec implements StreamResponseSpec {
private final DefaultChatClientRequestSpec request;
private final ChatClientRequest request;
public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) {
Assert.notNull(request, "request cannot be null");
this.request = request;
private final BaseAdvisorChain advisorChain;
private final ObservationRegistry observationRegistry;
private final ChatClientObservationConvention observationConvention;
public DefaultStreamResponseSpec(ChatClientRequest chatClientRequest, BaseAdvisorChain advisorChain,
ObservationRegistry observationRegistry, ChatClientObservationConvention observationConvention) {
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
Assert.notNull(advisorChain, "advisorChain cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.request = chatClientRequest;
this.advisorChain = advisorChain;
this.observationRegistry = observationRegistry;
this.observationConvention = observationConvention;
}
private Flux<ChatResponse> doGetObservableFluxChatResponse(DefaultChatClientRequestSpec inputRequest) {
private Flux<ChatClientResponse> doGetObservableFluxChatResponse(ChatClientRequest chatClientRequest) {
return Flux.deferContextual(contextView -> {
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(inputRequest)
.withStream(true)
.request(chatClientRequest)
.stream(true)
.build();
Observation observation = ChatClientObservationDocumentation.AI_CHAT_CLIENT.observation(
inputRequest.getCustomObservationConvention(), DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION,
() -> observationContext, inputRequest.getObservationRegistry());
observationConvention, DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION, () -> observationContext,
observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null))
.start();
var initialAdvisedRequest = toAdvisedRequest(inputRequest, null);
// @formatter:off
// Apply the around advisor chain that terminates with the last model call advisor.
Flux<AdvisedResponse> stream = inputRequest.aroundAdvisorChainBuilder.build().nextAroundStream(initialAdvisedRequest);
return stream
.map(AdvisedResponse::response)
// Apply the advisor chain that terminates with the ChatModelStreamAdvisor.
return advisorChain.nextStream(chatClientRequest)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
@@ -558,19 +611,29 @@ public class DefaultChatClient implements ChatClient {
}
@Override
public Flux<ChatResponse> chatResponse() {
public Flux<ChatClientResponse> chatClientResponse() {
return doGetObservableFluxChatResponse(this.request);
}
@Override
public Flux<ChatResponse> chatResponse() {
return doGetObservableFluxChatResponse(this.request).mapNotNull(ChatClientResponse::chatResponse);
}
@Override
public Flux<String> content() {
return doGetObservableFluxChatResponse(this.request).map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getText() == null) {
return "";
}
return r.getResult().getOutput().getText();
}).filter(StringUtils::hasLength);
// @formatter:off
return doGetObservableFluxChatResponse(this.request)
.mapNotNull(ChatClientResponse::chatResponse)
.map(r -> {
if (r.getResult() == null || r.getResult().getOutput() == null
|| r.getResult().getOutput().getText() == null) {
return "";
}
return r.getResult().getOutput().getText();
})
.filter(StringUtils::hasLength);
// @formatter:on
}
}
@@ -579,15 +642,15 @@ public class DefaultChatClient implements ChatClient {
private final ObservationRegistry observationRegistry;
private final ChatClientObservationConvention customObservationConvention;
private final ChatClientObservationConvention observationConvention;
private final ChatModel chatModel;
private final List<Media> media = new ArrayList<>();
private final List<String> functionNames = new ArrayList<>();
private final List<String> toolNames = new ArrayList<>();
private final List<FunctionCallback> functionCallbacks = new ArrayList<>();
private final List<FunctionCallback> toolCallbacks = new ArrayList<>();
private final List<Message> messages = new ArrayList<>();
@@ -614,25 +677,24 @@ public class DefaultChatClient implements ChatClient {
/* copy constructor */
DefaultChatClientRequestSpec(DefaultChatClientRequestSpec ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.observationRegistry, ccr.customObservationConvention, ccr.toolContext);
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.toolCallbacks,
ccr.messages, ccr.toolNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.observationRegistry, ccr.observationConvention, ccr.toolContext);
}
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, @Nullable String systemText, Map<String, Object> systemParams,
List<FunctionCallback> functionCallbacks, List<Message> messages, List<String> functionNames,
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention,
Map<String, Object> toolContext) {
List<FunctionCallback> toolCallbacks, List<Message> messages, List<String> toolNames, List<Media> media,
@Nullable ChatOptions chatOptions, List<Advisor> advisors, Map<String, Object> advisorParams,
ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(userParams, "userParams cannot be null");
Assert.notNull(systemParams, "systemParams cannot be null");
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.notNull(messages, "messages cannot be null");
Assert.notNull(functionNames, "functionNames cannot be null");
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.notNull(media, "media cannot be null");
Assert.notNull(advisors, "advisors cannot be null");
Assert.notNull(advisorParams, "advisorParams cannot be null");
@@ -648,58 +710,21 @@ public class DefaultChatClient implements ChatClient {
this.systemText = systemText;
this.systemParams.putAll(systemParams);
this.functionNames.addAll(functionNames);
this.functionCallbacks.addAll(functionCallbacks);
this.toolNames.addAll(toolNames);
this.toolCallbacks.addAll(toolCallbacks);
this.messages.addAll(messages);
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.customObservationConvention = customObservationConvention != null ? customObservationConvention
this.observationConvention = observationConvention != null ? observationConvention
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
this.toolContext.putAll(toolContext);
// @formatter:off
// At the stack bottom add the non-streaming and streaming model call advisors.
// They play the role of the last advisor in the around advisor chain.
this.advisors.add(new CallAroundAdvisor() {
@Override
public String getName() {
return CallAroundAdvisor.class.getSimpleName();
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
return new AdvisedResponse(chatModel.call(advisedRequest.toPrompt()), Collections.unmodifiableMap(advisedRequest.adviseContext()));
}
});
this.advisors.add(new StreamAroundAdvisor() {
@Override
public String getName() {
return StreamAroundAdvisor.class.getSimpleName();
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
@Override
public Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
return chatModel.stream(advisedRequest.toPrompt())
.map(chatResponse -> new AdvisedResponse(chatResponse, Collections.unmodifiableMap(advisedRequest.adviseContext())))
.publishOn(Schedulers.boundedElastic()); // TODO add option to disable.
}
});
// @formatter:on
// At the stack bottom add the model call advisors.
// They play the role of the last advisors in the advisor chain.
this.advisors.add(new ChatModelCallAdvisor(chatModel));
this.advisors.add(new ChatModelStreamAdvisor(chatModel));
this.aroundAdvisorChainBuilder = DefaultAroundAdvisorChain.builder(observationRegistry)
.pushAll(this.advisors);
@@ -710,7 +735,7 @@ public class DefaultChatClient implements ChatClient {
}
private ChatClientObservationConvention getCustomObservationConvention() {
return this.customObservationConvention;
return this.observationConvention;
}
@Nullable
@@ -753,11 +778,11 @@ public class DefaultChatClient implements ChatClient {
}
public List<String> getFunctionNames() {
return this.functionNames;
return this.toolNames;
}
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.toolCallbacks;
}
public Map<String, Object> getToolContext() {
@@ -770,8 +795,8 @@ public class DefaultChatClient implements ChatClient {
*/
public Builder mutate() {
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
.builder(this.chatModel, this.observationRegistry, this.customObservationConvention)
.defaultFunctions(StringUtils.toStringArray(this.functionNames));
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
.defaultTools(StringUtils.toStringArray(this.toolNames));
if (StringUtils.hasText(this.userText)) {
builder.defaultUser(
@@ -787,7 +812,7 @@ public class DefaultChatClient implements ChatClient {
}
builder.addMessages(this.messages);
builder.addToolCallbacks(this.functionCallbacks);
builder.addToolCallbacks(this.toolCallbacks);
builder.addToolContext(this.toolContext);
return builder;
@@ -843,7 +868,7 @@ public class DefaultChatClient implements ChatClient {
public ChatClientRequestSpec tools(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
this.functionNames.addAll(List.of(toolNames));
this.toolNames.addAll(List.of(toolNames));
return this;
}
@@ -851,7 +876,7 @@ public class DefaultChatClient implements ChatClient {
public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.functionCallbacks.addAll(List.of(toolCallbacks));
this.toolCallbacks.addAll(List.of(toolCallbacks));
return this;
}
@@ -859,7 +884,7 @@ public class DefaultChatClient implements ChatClient {
public ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.functionCallbacks.addAll(toolCallbacks);
this.toolCallbacks.addAll(toolCallbacks);
return this;
}
@@ -867,7 +892,7 @@ public class DefaultChatClient implements ChatClient {
public ChatClientRequestSpec tools(Object... toolObjects) {
Assert.notNull(toolObjects, "toolObjects cannot be null");
Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
this.toolCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
return this;
}
@@ -876,7 +901,7 @@ public class DefaultChatClient implements ChatClient {
Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
this.toolCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
}
return this;
}
@@ -890,7 +915,7 @@ public class DefaultChatClient implements ChatClient {
public ChatClientRequestSpec functions(FunctionCallback... functionCallbacks) {
Assert.notNull(functionCallbacks, "functionCallbacks cannot be null");
Assert.noNullElements(functionCallbacks, "functionCallbacks cannot contain null elements");
this.functionCallbacks.addAll(Arrays.asList(functionCallbacks));
this.toolCallbacks.addAll(Arrays.asList(functionCallbacks));
return this;
}
@@ -973,17 +998,22 @@ public class DefaultChatClient implements ChatClient {
}
public CallResponseSpec call() {
return new DefaultCallResponseSpec(this);
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
return new DefaultCallResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
observationRegistry, observationConvention);
}
public StreamResponseSpec stream() {
return new DefaultStreamResponseSpec(this);
BaseAdvisorChain advisorChain = aroundAdvisorChainBuilder.build();
return new DefaultStreamResponseSpec(toAdvisedRequest(this).toChatClientRequest(), advisorChain,
observationRegistry, observationConvention);
}
}
// Prompt
@Deprecated // never used, to be removed
public static class DefaultCallPromptResponseSpec implements CallPromptResponseSpec {
private final ChatModel chatModel;
@@ -1015,6 +1045,7 @@ public class DefaultChatClient implements ChatClient {
}
@Deprecated // never used, to be removed
public static class DefaultStreamPromptResponseSpec implements StreamPromptResponseSpec {
private final Prompt prompt;

View File

@@ -0,0 +1,65 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.core.Ordered;
import org.springframework.util.Assert;
import java.util.Map;
/**
* A {@link CallAdvisor} that uses a {@link ChatModel} to generate a response.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public final class ChatModelCallAdvisor implements CallAdvisor {
private final ChatModel chatModel;
public ChatModelCallAdvisor(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Override
public ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
ChatResponse chatResponse = chatModel.call(chatClientRequest.prompt());
return ChatClientResponse.builder()
.chatResponse(chatResponse)
.context(Map.copyOf(chatClientRequest.context()))
.build();
}
@Override
public String getName() {
return "call";
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
}

View File

@@ -0,0 +1,66 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.*;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.core.Ordered;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.util.Map;
/**
* A {@link StreamAdvisor} that uses a {@link ChatModel} to generate a streaming response.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public final class ChatModelStreamAdvisor implements StreamAdvisor {
private final ChatModel chatModel;
public ChatModelStreamAdvisor(ChatModel chatModel) {
this.chatModel = chatModel;
}
@Override
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
return chatModel.stream(chatClientRequest.prompt())
.map(chatResponse -> ChatClientResponse.builder()
.chatResponse(chatResponse)
.context(Map.copyOf(chatClientRequest.context()))
.build())
.publishOn(Schedulers.boundedElastic()); // TODO add option to disable
}
@Override
public String getName() {
return "stream";
}
@Override
public int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,15 +23,18 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationContext;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;
@@ -41,16 +44,16 @@ import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* Implementation of the {@link CallAroundAdvisorChain} and
* {@link StreamAroundAdvisorChain}. Used by the
* Default implementation for the {@link BaseAdvisorChain}. Used by the
* {@link org.springframework.ai.chat.client.ChatClient} to delegate the call to the next
* {@link CallAroundAdvisor} or {@link StreamAroundAdvisor} in the chain.
* {@link CallAdvisor} or {@link StreamAdvisor} in the chain.
*
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @author Thomas Vitale
* @since 1.0.0
*/
public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, StreamAroundAdvisorChain {
public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();
@@ -77,7 +80,42 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
}
@Override
public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
if (this.callAroundAdvisors.isEmpty()) {
throw new IllegalStateException("No CallAdvisors available to execute");
}
var advisor = this.callAroundAdvisors.pop();
var observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.chatClientRequest(chatClientRequest)
.order(advisor.getOrder())
.build();
return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
.observe(() -> {
// Supports both deprecated and new API.
if (advisor instanceof CallAdvisor callAdvisor) {
return callAdvisor.adviseCall(chatClientRequest, this);
}
AdvisedResponse advisedResponse = advisor.aroundCall(AdvisedRequest.from(chatClientRequest), this);
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
observationContext.setChatClientResponse(chatClientResponse);
return chatClientResponse;
});
}
/**
* @deprecated Use {@link #nextCall(ChatClientRequest)} instead
*/
@Override
@Deprecated
public AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
if (this.callAroundAdvisors.isEmpty()) {
throw new IllegalStateException("No AroundAdvisor available to execute");
@@ -87,31 +125,40 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
var observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.advisorType(AdvisorObservationContext.Type.AROUND)
.advisedRequest(advisedRequest)
.advisorRequestContext(advisedRequest.adviseContext())
.chatClientRequest(advisedRequest.toChatClientRequest())
.order(advisor.getOrder())
.build();
return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
.observe(() -> advisor.aroundCall(advisedRequest, this));
.observe(() -> {
// Supports both deprecated and new API.
if (advisor instanceof CallAdvisor callAdvisor) {
ChatClientResponse chatClientResponse = callAdvisor.adviseCall(advisedRequest.toChatClientRequest(),
this);
return AdvisedResponse.from(chatClientResponse);
}
AdvisedResponse advisedResponse = advisor.aroundCall(advisedRequest, this);
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
observationContext.setChatClientResponse(chatClientResponse);
return advisedResponse;
});
}
@Override
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest) {
Assert.notNull(chatClientRequest, "the chatClientRequest cannot be null");
return Flux.deferContextual(contextView -> {
if (this.streamAroundAdvisors.isEmpty()) {
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
return Flux.error(new IllegalStateException("No StreamAdvisors available to execute"));
}
var advisor = this.streamAroundAdvisors.pop();
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.advisorType(AdvisorObservationContext.Type.AROUND)
.advisedRequest(advisedRequest)
.advisorRequestContext(advisedRequest.adviseContext())
.chatClientRequest(chatClientRequest)
.order(advisor.getOrder())
.build();
@@ -121,10 +168,66 @@ public class DefaultAroundAdvisorChain implements CallAroundAdvisorChain, Stream
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
// @formatter:off
return Flux.defer(() -> advisor.aroundStream(advisedRequest, this))
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
return Flux.defer(() -> {
// Supports both deprecated and new API.
if (advisor instanceof StreamAdvisor streamAdvisor) {
return streamAdvisor.adviseStream(chatClientRequest, this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
}
return advisor.aroundStream(AdvisedRequest.from(chatClientRequest), this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
.map(AdvisedResponse::toChatClientResponse);
});
// @formatter:on
});
}
/**
* @deprecated Use {@link #nextStream(ChatClientRequest)} instead.
*/
@Override
@Deprecated
public Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
Assert.notNull(advisedRequest, "the advisedRequest cannot be null");
return Flux.deferContextual(contextView -> {
if (this.streamAroundAdvisors.isEmpty()) {
return Flux.error(new IllegalStateException("No AroundAdvisor available to execute"));
}
var advisor = this.streamAroundAdvisors.pop();
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName(advisor.getName())
.chatClientRequest(advisedRequest.toChatClientRequest())
.order(advisor.getOrder())
.build();
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
// @formatter:off
return Flux.defer(() -> {
// Supports both deprecated and new API.
if (advisor instanceof StreamAdvisor streamAdvisor) {
return streamAdvisor.adviseStream(advisedRequest.toChatClientRequest(), this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation))
.map(AdvisedResponse::from);
}
return advisor.aroundStream(advisedRequest, this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
});
// @formatter:on
});
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,10 +20,14 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.Objects;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
@@ -34,6 +38,7 @@ import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -60,8 +65,10 @@ import org.springframework.util.StringUtils;
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @deprecated Use {@link ChatClientRequest} instead.
* @since 1.0.0
*/
@Deprecated
public record AdvisedRequest(
// @formatter:off
ChatModel chatModel,
@@ -77,6 +84,7 @@ public record AdvisedRequest(
Map<String, Object> userParams,
Map<String, Object> systemParams,
List<Advisor> advisors,
@Deprecated // Not really used. Use "adviseContext" instead.
Map<String, Object> advisorParams,
Map<String, Object> adviseContext,
Map<String, Object> toolContext
@@ -139,6 +147,52 @@ public record AdvisedRequest(
return builder;
}
@SuppressWarnings("unchecked")
public static AdvisedRequest from(ChatClientRequest from) {
Assert.notNull(from, "ChatClientRequest cannot be null");
List<Message> messages = new LinkedList<>(from.prompt().getInstructions());
Builder builder = new Builder();
if (from.context().get(ChatClientAttributes.CHAT_MODEL.getKey()) instanceof ChatModel chatModel) {
builder.chatModel = chatModel;
}
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof UserMessage userMessage) {
builder.userText = userMessage.getText();
builder.media = userMessage.getMedia();
messages.remove(messages.size() - 1);
}
if (from.context().get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> contextUserParams) {
builder.userParams = (Map<String, Object>) contextUserParams;
}
if (!messages.isEmpty() && messages.get(messages.size() - 1) instanceof SystemMessage systemMessage) {
builder.systemText = systemMessage.getText();
messages.remove(messages.size() - 1);
}
if (from.context().get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> contextSystemParams) {
builder.systemParams = (Map<String, Object>) contextSystemParams;
}
builder.messages = messages;
builder.chatOptions = Objects.requireNonNullElse(from.prompt().getOptions(), ChatOptions.builder().build());
if (from.prompt().getOptions() instanceof ToolCallingChatOptions options) {
builder.functionNames = options.getToolNames().stream().toList();
builder.functionCallbacks = options.getToolCallbacks();
builder.toolContext = options.getToolContext();
}
if (from.context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List<?> advisors) {
builder.advisors = (List<Advisor>) advisors;
}
builder.advisorParams = Map.of();
builder.adviseContext = from.context();
return builder.build();
}
public AdvisedRequest updateContext(Function<Map<String, Object>, Map<String, Object>> contextTransform) {
Assert.notNull(contextTransform, "contextTransform cannot be null");
return from(this)
@@ -146,6 +200,17 @@ public record AdvisedRequest(
.build();
}
public ChatClientRequest toChatClientRequest() {
return ChatClientRequest.builder()
.prompt(toPrompt())
.context(this.adviseContext)
.context(ChatClientAttributes.ADVISORS.getKey(), this.advisors)
.context(ChatClientAttributes.CHAT_MODEL.getKey(), this.chatModel)
.context(ChatClientAttributes.USER_PARAMS.getKey(), this.userParams)
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), this.systemParams)
.build();
}
public Prompt toPrompt() {
var messages = new ArrayList<>(this.messages());
@@ -157,16 +222,9 @@ public record AdvisedRequest(
messages.add(new SystemMessage(processedSystemText));
}
String formatParam = (String) this.adviseContext().get("formatParam");
var processedUserText = StringUtils.hasText(formatParam)
? this.userText() + System.lineSeparator() + "{spring_ai_soc_format}" : this.userText();
if (StringUtils.hasText(processedUserText)) {
if (StringUtils.hasText(this.userText())) {
Map<String, Object> userParams = new HashMap<>(this.userParams());
if (StringUtils.hasText(formatParam)) {
userParams.put("spring_ai_soc_format", formatParam);
}
String processedUserText = this.userText();
if (!CollectionUtils.isEmpty(userParams)) {
processedUserText = new PromptTemplate(processedUserText, userParams).render();
}
@@ -338,7 +396,9 @@ public record AdvisedRequest(
* Set the advisor params.
* @param advisorParams the advisor params
* @return this {@link Builder} instance
* @deprecated in favor of {@link #adviseContext(Map)}
*/
@Deprecated
public Builder advisorParams(Map<String, Object> advisorParams) {
this.advisorParams = advisorParams;
return this;

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@@ -33,8 +34,10 @@ import org.springframework.util.Assert;
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @deprecated Use {@link ChatClientResponse} instead.
* @since 1.0.0
*/
@Deprecated
public record AdvisedResponse(@Nullable ChatResponse response, Map<String, Object> adviseContext) {
/**
@@ -66,6 +69,15 @@ public record AdvisedResponse(@Nullable ChatResponse response, Map<String, Objec
return new Builder().response(advisedResponse.response).adviseContext(advisedResponse.adviseContext);
}
public static AdvisedResponse from(ChatClientResponse chatClientResponse) {
Assert.notNull(chatClientResponse, "chatClientResponse cannot be null");
return new AdvisedResponse(chatClientResponse.chatResponse(), chatClientResponse.context());
}
public ChatClientResponse toChatClientResponse() {
return new ChatClientResponse(this.response, this.adviseContext);
}
/**
* Update the context of the advised response.
* @param contextTransform the function to transform the context

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,9 +24,9 @@ import org.springframework.core.Ordered;
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @since 1.0.0
* @see CallAroundAdvisor
* @see StreamAroundAdvisor
* @see CallAroundAdvisorChain
* @see CallAdvisor
* @see StreamAdvisor
* @see BaseAdvisor
*/
public interface Advisor extends Ordered {

View File

@@ -0,0 +1,28 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.api;
/**
* A base interface for advisor chains that can be used to chain multiple advisors
* together, both for call and stream advisors.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface BaseAdvisorChain extends CallAdvisorChain, StreamAdvisorChain {
}

View File

@@ -0,0 +1,41 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
/**
* Advisor for execution flows ultimately resulting in a call to an AI model
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface CallAdvisor extends CallAroundAdvisor {
/**
* @deprecated use {@link #adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
*/
@Deprecated
default AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
ChatClientResponse chatClientResponse = adviseCall(advisedRequest.toChatClientRequest(), chain);
return AdvisedResponse.from(chatClientResponse);
}
ChatClientResponse adviseCall(ChatClientRequest chatClientRequest, CallAroundAdvisorChain chain);
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
/**
* A chain of {@link CallAdvisor} instances orchestrating the execution of a
* {@link ChatClientRequest} on the next {@link CallAdvisor} in the chain.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface CallAdvisorChain extends CallAroundAdvisorChain {
/**
* @deprecated use {@link #nextCall(ChatClientRequest)}
*/
@Deprecated
default AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest) {
ChatClientResponse chatClientResponse = nextCall(advisedRequest.toChatClientRequest());
return AdvisedResponse.from(chatClientResponse);
}
ChatClientResponse nextCall(ChatClientRequest chatClientRequest);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,14 +16,17 @@
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
/**
* Around advisor that wraps the ChatModel#call(Prompt) method.
*
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @since 1.0.0
* @deprecated in favor of {@link CallAdvisor}
*/
@Deprecated
public interface CallAroundAdvisor extends Advisor {
/**
@@ -31,7 +34,10 @@ public interface CallAroundAdvisor extends Advisor {
* @param advisedRequest the advised request
* @param chain the advisor chain
* @return the response
* @deprecated in favor of
* {@link CallAdvisor#adviseCall(ChatClientRequest, CallAroundAdvisorChain)}
*/
@Deprecated
AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
/**
* The Call Around Advisor Chain is used to invoke the next Around Advisor in the chain.
* Used for non-streaming responses.
@@ -23,7 +25,9 @@ package org.springframework.ai.chat.client.advisor.api;
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @since 1.0.0
* @deprecated in favor of {@link CallAdvisorChain}
*/
@Deprecated
public interface CallAroundAdvisorChain {
/**
@@ -32,7 +36,9 @@ public interface CallAroundAdvisorChain {
* @param advisedRequest the request containing the data to be processed by the next
* advisor in the chain.
* @return the response generated by the next advisor in the chain.
* @deprecated in favor of {@link CallAdvisorChain#nextCall(ChatClientRequest)}
*/
@Deprecated
AdvisedResponse nextAroundCall(AdvisedRequest advisedRequest);
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import reactor.core.publisher.Flux;
/**
* Advisor for execution flows ultimately resulting in a streaming call to an AI model.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface StreamAdvisor extends StreamAroundAdvisor {
/**
* @deprecated use {@link #adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
*/
@Deprecated
default Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain) {
Flux<ChatClientResponse> chatClientResponse = adviseStream(advisedRequest.toChatClientRequest(), chain);
return chatClientResponse.map(AdvisedResponse::from);
}
Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest, StreamAroundAdvisorChain chain);
}

View File

@@ -0,0 +1,43 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import reactor.core.publisher.Flux;
/**
* A chain of {@link StreamAdvisor} instances orchestrating the execution of a
* {@link ChatClientRequest} on the next {@link StreamAdvisor} in the chain.
*
* @author Thomas Vitale
* @since 1.0.0
*/
public interface StreamAdvisorChain extends StreamAroundAdvisorChain {
/**
* @deprecated use {@link #nextStream(ChatClientRequest)}
*/
@Deprecated
default Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest) {
Flux<ChatClientResponse> chatClientResponse = nextStream(advisedRequest.toChatClientRequest());
return chatClientResponse.map(AdvisedResponse::from);
}
Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import reactor.core.publisher.Flux;
/**
@@ -24,7 +25,9 @@ import reactor.core.publisher.Flux;
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @since 1.0.0
* @deprecated in favor of {@link StreamAdvisor}
*/
@Deprecated
public interface StreamAroundAdvisor extends Advisor {
/**
@@ -32,7 +35,10 @@ public interface StreamAroundAdvisor extends Advisor {
* @param advisedRequest the advised request
* @param chain the chain of advisors to execute
* @return the result of the advised request
* @deprecated in favor of
* {@link StreamAdvisor#adviseStream(ChatClientRequest, StreamAroundAdvisorChain)}
*/
@Deprecated
Flux<AdvisedResponse> aroundStream(AdvisedRequest advisedRequest, StreamAroundAdvisorChain chain);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
package org.springframework.ai.chat.client.advisor.api;
import org.springframework.ai.chat.client.ChatClientRequest;
import reactor.core.publisher.Flux;
/**
@@ -25,7 +26,9 @@ import reactor.core.publisher.Flux;
* @author Christian Tzolov
* @author Dariusz Jedrzejczyk
* @since 1.0.0
* @deprecated in favor of {@link StreamAdvisorChain}
*/
@Deprecated
public interface StreamAroundAdvisorChain {
/**
@@ -34,7 +37,9 @@ public interface StreamAroundAdvisorChain {
* @param advisedRequest the request containing data of the chat client that can be
* modified before execution
* @return a Flux stream of AdvisedResponse objects
* @deprecated in favor of {@link StreamAdvisorChain#nextStream(ChatClientRequest)}
*/
@Deprecated
Flux<AdvisedResponse> nextAroundStream(AdvisedRequest advisedRequest);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,10 +21,13 @@ import java.util.Map;
import io.micrometer.observation.Observation;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
/**
* Context used to store metadata for chat client advisors.
@@ -37,26 +40,12 @@ public class AdvisorObservationContext extends Observation.Context {
private final String advisorName;
private final Type advisorType;
private final ChatClientRequest chatClientRequest;
/**
* The order of the advisor in the advisor chain.
*/
private final int order;
/**
* The {@link AdvisedRequest} data to be advised. Represents the row
* {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}.
*/
@Nullable
private AdvisedRequest advisorRequest;
/**
* The shared data between the advisors in the chain. It is shared between all request
* and response advising points of all advisors in the chain.
*/
@Nullable
private Map<String, Object> advisorRequestContext;
private ChatClientResponse chatClientResponse;
/**
* the shared data between the advisors in the chain. It is shared between all request
@@ -73,18 +62,32 @@ public class AdvisorObservationContext extends Observation.Context {
* @param advisorRequestContext the shared data between the advisors in the chain
* @param advisorResponseContext the shared data between the advisors in the chain
* @param order the order of the advisor in the advisor chain
* @deprecated use the builder instead
*/
@Deprecated
public AdvisorObservationContext(String advisorName, Type advisorType, @Nullable AdvisedRequest advisorRequest,
@Nullable Map<String, Object> advisorRequestContext, @Nullable Map<String, Object> advisorResponseContext,
int order) {
Assert.hasText(advisorName, "advisorName must not be null or empty");
Assert.notNull(advisorType, "advisorType must not be null");
Assert.hasText(advisorName, "advisorName cannot be null or empty");
this.advisorName = advisorName;
this.advisorType = advisorType;
this.advisorRequest = advisorRequest;
this.advisorRequestContext = advisorRequestContext;
this.advisorResponseContext = advisorResponseContext;
this.chatClientRequest = advisorRequest != null ? advisorRequest.toChatClientRequest()
: ChatClientRequest.builder().prompt(new Prompt()).build();
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
this.chatClientRequest.context().putAll(advisorRequestContext);
}
if (!CollectionUtils.isEmpty(advisorResponseContext)) {
this.chatClientResponse = ChatClientResponse.builder().context(advisorResponseContext).build();
}
this.order = order;
}
AdvisorObservationContext(String advisorName, ChatClientRequest chatClientRequest, int order) {
Assert.hasText(advisorName, "advisorName cannot be null or empty");
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
this.advisorName = advisorName;
this.chatClientRequest = chatClientRequest;
this.order = order;
}
@@ -96,89 +99,115 @@ public class AdvisorObservationContext extends Observation.Context {
return new Builder();
}
/**
* The advisor name.
* @return the advisor name
*/
public String getAdvisorName() {
return this.advisorName;
}
public ChatClientRequest getChatClientRequest() {
return this.chatClientRequest;
}
public int getOrder() {
return this.order;
}
@Nullable
public ChatClientResponse getChatClientResponse() {
return this.chatClientResponse;
}
public void setChatClientResponse(@Nullable ChatClientResponse chatClientResponse) {
this.chatClientResponse = chatClientResponse;
}
/**
* The type of the advisor.
* @return the type of the advisor
* @deprecated advisors don't have types anymore, they're all "around"
*/
@Deprecated
public Type getAdvisorType() {
return this.advisorType;
return Type.AROUND;
}
/**
* The order of the advisor in the advisor chain.
* @return the order of the advisor in the advisor chain
* @deprecated not used anymore
*/
@Nullable
@Deprecated
public AdvisedRequest getAdvisedRequest() {
return this.advisorRequest;
return AdvisedRequest.from(this.chatClientRequest);
}
/**
* Set the {@link AdvisedRequest} data to be advised. Represents the row
* {@link ChatClient.ChatClientRequestSpec} data before sealed into a {@link Prompt}.
* @param advisedRequest the advised request
* @deprecated immutable object, use the builder instead to create a new instance
*/
@Deprecated
public void setAdvisedRequest(@Nullable AdvisedRequest advisedRequest) {
this.advisorRequest = advisedRequest;
throw new IllegalStateException(
"The AdvisedRequest is immutable. Build a new AdvisorObservationContext instead.");
}
/**
* Get the shared data between the advisors in the chain. It is shared between all
* request and response advising points of all advisors in the chain.
* @return the shared data between the advisors in the chain
* @deprecated use {@link #getChatClientRequest()} instead
*/
@Nullable
@Deprecated
public Map<String, Object> getAdvisorRequestContext() {
return this.advisorRequestContext;
return this.chatClientRequest.context();
}
/**
* Set the shared data between the advisors in the chain. It is shared between all
* request and response advising points of all advisors in the chain.
* @param advisorRequestContext the shared data between the advisors in the chain
* @deprecated not supported anymore, use {@link #getChatClientRequest()} instead
*/
@Deprecated
public void setAdvisorRequestContext(@Nullable Map<String, Object> advisorRequestContext) {
this.advisorRequestContext = advisorRequestContext;
if (!CollectionUtils.isEmpty(advisorRequestContext)) {
this.chatClientRequest.context().putAll(advisorRequestContext);
}
}
/**
* Get the shared data between the advisors in the chain. It is shared between all
* request and response advising points of all advisors in the chain.
* @return the shared data between the advisors in the chain
* @deprecated use {@link #getChatClientResponse()} instead
*/
@Nullable
@Deprecated
public Map<String, Object> getAdvisorResponseContext() {
return this.advisorResponseContext;
if (this.chatClientResponse != null) {
return this.chatClientResponse.context();
}
return null;
}
/**
* Set the shared data between the advisors in the chain. It is shared between all
* request and response advising points of all advisors in the chain.
* @param advisorResponseContext the shared data between the advisors in the chain
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
*/
@Deprecated
public void setAdvisorResponseContext(@Nullable Map<String, Object> advisorResponseContext) {
this.advisorResponseContext = advisorResponseContext;
}
/**
* The order of the advisor in the advisor chain.
* @return the order of the advisor in the advisor chain
*/
public int getOrder() {
return this.order;
}
/**
* The type of the advisor.
*
* @deprecated advisors don't have types anymore, they're all "around"
*/
@Deprecated
public enum Type {
/**
@@ -203,7 +232,9 @@ public class AdvisorObservationContext extends Observation.Context {
private String advisorName;
private Type advisorType;
private ChatClientRequest chatClientRequest;
private int order = 0;
private AdvisedRequest advisorRequest;
@@ -211,28 +242,32 @@ public class AdvisorObservationContext extends Observation.Context {
private Map<String, Object> advisorResponseContext;
private int order = 0;
private Builder() {
}
/**
* Set the advisor name.
* @param advisorName the advisor name
* @return the builder
*/
public Builder advisorName(String advisorName) {
this.advisorName = advisorName;
return this;
}
public Builder chatClientRequest(ChatClientRequest chatClientRequest) {
this.chatClientRequest = chatClientRequest;
return this;
}
public Builder order(int order) {
this.order = order;
return this;
}
/**
* Set the advisor type.
* @param advisorType the advisor type
* @return the builder
* @deprecated advisors don't have types anymore, they're all "around"
*/
@Deprecated
public Builder advisorType(Type advisorType) {
this.advisorType = advisorType;
return this;
}
@@ -240,7 +275,9 @@ public class AdvisorObservationContext extends Observation.Context {
* Set the advised request.
* @param advisedRequest the advised request
* @return the builder
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
*/
@Deprecated
public Builder advisedRequest(AdvisedRequest advisedRequest) {
this.advisorRequest = advisedRequest;
return this;
@@ -250,7 +287,9 @@ public class AdvisorObservationContext extends Observation.Context {
* Set the advisor request context.
* @param advisorRequestContext the advisor request context
* @return the builder
* @deprecated use {@link #chatClientRequest(ChatClientRequest)} instead
*/
@Deprecated
public Builder advisorRequestContext(Map<String, Object> advisorRequestContext) {
this.advisorRequestContext = advisorRequestContext;
return this;
@@ -260,29 +299,26 @@ public class AdvisorObservationContext extends Observation.Context {
* Set the advisor response context.
* @param advisorResponseContext the advisor response context
* @return the builder
* @deprecated use {@link #setChatClientResponse(ChatClientResponse)} instead
*/
@Deprecated
public Builder advisorResponseContext(Map<String, Object> advisorResponseContext) {
this.advisorResponseContext = advisorResponseContext;
return this;
}
/**
* Set the order of the advisor in the advisor chain.
* @param order the order of the advisor in the advisor chain
* @return the builder
*/
public Builder order(int order) {
this.order = order;
return this;
}
/**
* Build the {@link AdvisorObservationContext}.
* @return the {@link AdvisorObservationContext}
*/
public AdvisorObservationContext build() {
return new AdvisorObservationContext(this.advisorName, this.advisorType, this.advisorRequest,
this.advisorRequestContext, this.advisorResponseContext, this.order);
if (chatClientRequest != null && advisorRequest != null) {
throw new IllegalArgumentException(
"ChatClientRequest and AdvisedRequest cannot be set at the same time");
}
else if (chatClientRequest != null) {
return new AdvisorObservationContext(this.advisorName, this.chatClientRequest, this.order);
}
else {
return new AdvisorObservationContext(this.advisorName, Type.AROUND, this.advisorRequest,
this.advisorRequestContext, this.advisorResponseContext, this.order);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.observation.conventions.SpringAiKind;
import org.springframework.ai.util.ParsingUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
/**
* Default implementation of the {@link AdvisorObservationConvention}.
@@ -55,6 +56,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
@Override
@Nullable
public String getContextualName(AdvisorObservationContext context) {
Assert.notNull(context, "context cannot be null");
return ParsingUtils.reConcatenateCamelCase(context.getAdvisorName(), "_")
.replace("_around_advisor", "")
.replace("_advisor", "");
@@ -66,6 +68,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
@Override
public KeyValues getLowCardinalityKeyValues(AdvisorObservationContext context) {
Assert.notNull(context, "context cannot be null");
return KeyValues.of(aiOperationType(context), aiProvider(context), springAiKind(), advisorType(context),
advisorName(context));
}
@@ -78,6 +81,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
return KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER, AiProvider.SPRING_AI.value());
}
@Deprecated
protected KeyValue advisorType(AdvisorObservationContext context) {
return KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE, context.getAdvisorType().name());
}
@@ -96,6 +100,7 @@ public class DefaultAdvisorObservationConvention implements AdvisorObservationCo
@Override
public KeyValues getHighCardinalityKeyValues(AdvisorObservationContext context) {
Assert.notNull(context, "context cannot be null");
return KeyValues.of(advisorOrder(context));
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,9 +20,15 @@ import io.micrometer.common.KeyValue;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationFilter;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.observation.tracing.TracingHelper;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import java.util.List;
import java.util.Map;
/**
* An {@link ObservationFilter} to include the chat prompt content in the observation.
@@ -37,7 +43,8 @@ public class ChatClientInputContentObservationFilter implements ObservationFilte
if (!(context instanceof ChatClientObservationContext chatClientObservationContext)) {
return context;
}
// TODO: we really want these? Should probably align with same format as chat
// model observation
chatClientSystemText(chatClientObservationContext);
chatClientSystemParams(chatClientObservationContext);
chatClientUserText(chatClientObservationContext);
@@ -47,39 +54,65 @@ public class ChatClientInputContentObservationFilter implements ObservationFilte
}
protected void chatClientSystemText(ChatClientObservationContext context) {
if (!StringUtils.hasText(context.getRequest().getSystemText())) {
List<Message> messages = context.getRequest().prompt().getInstructions();
if (CollectionUtils.isEmpty(messages)) {
return;
}
var systemMessage = messages.stream()
.filter(message -> message instanceof SystemMessage)
.reduce((first, second) -> second);
if (systemMessage.isEmpty()) {
return;
}
context.addHighCardinalityKeyValue(
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_TEXT,
context.getRequest().getSystemText()));
systemMessage.get().getText()));
}
@SuppressWarnings("unchecked")
protected void chatClientSystemParams(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getSystemParams())) {
if (!(context.getRequest()
.context()
.get(ChatClientAttributes.SYSTEM_PARAMS.getKey()) instanceof Map<?, ?> systemParams)) {
return;
}
if (CollectionUtils.isEmpty(systemParams)) {
return;
}
context.addHighCardinalityKeyValue(
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_SYSTEM_PARAM,
TracingHelper.concatenateMaps(context.getRequest().getSystemParams())));
TracingHelper.concatenateMaps((Map<String, Object>) systemParams)));
}
protected void chatClientUserText(ChatClientObservationContext context) {
if (!StringUtils.hasText(context.getRequest().getUserText())) {
List<Message> messages = context.getRequest().prompt().getInstructions();
if (CollectionUtils.isEmpty(messages)) {
return;
}
if (!(messages.get(messages.size() - 1) instanceof UserMessage userMessage)) {
return;
}
context.addHighCardinalityKeyValue(
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_TEXT,
context.getRequest().getUserText()));
userMessage.getText()));
}
@SuppressWarnings("unchecked")
protected void chatClientUserParams(ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getUserParams())) {
if (!(context.getRequest()
.context()
.get(ChatClientAttributes.USER_PARAMS.getKey()) instanceof Map<?, ?> userParams)) {
return;
}
if (CollectionUtils.isEmpty(userParams)) {
return;
}
context.addHighCardinalityKeyValue(
KeyValue.of(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_USER_PARAMS,
TracingHelper.concatenateMaps(context.getRequest().getUserParams())));
TracingHelper.concatenateMaps((Map<String, Object>) userParams)));
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ package org.springframework.ai.chat.client.observation;
import io.micrometer.observation.Observation;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.observation.AiOperationMetadata;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Context used to store metadata for chat client workflows.
@@ -34,20 +36,16 @@ import org.springframework.util.Assert;
*/
public class ChatClientObservationContext extends Observation.Context {
private final DefaultChatClientRequestSpec request;
private final ChatClientRequest request;
private final AiOperationMetadata operationMetadata = new AiOperationMetadata(AiOperationType.FRAMEWORK.value(),
AiProvider.SPRING_AI.value());
private final boolean stream;
@Nullable
private String format;
ChatClientObservationContext(DefaultChatClientRequestSpec requestSpec, String format, boolean isStream) {
Assert.notNull(requestSpec, "requestSpec cannot be null");
this.request = requestSpec;
this.format = format;
ChatClientObservationContext(ChatClientRequest chatClientRequest, boolean isStream) {
Assert.notNull(chatClientRequest, "chatClientRequest cannot be null");
this.request = chatClientRequest;
this.stream = isStream;
}
@@ -55,7 +53,7 @@ public class ChatClientObservationContext extends Observation.Context {
return new Builder();
}
public DefaultChatClientRequestSpec getRequest() {
public ChatClientRequest getRequest() {
return this.request;
}
@@ -67,18 +65,31 @@ public class ChatClientObservationContext extends Observation.Context {
return this.stream;
}
/**
* @deprecated not used anymore. The format instructions are already included in the
* ChatModelObservationContext.
*/
@Nullable
@Deprecated
public String getFormat() {
return this.format;
if (this.request.context().get(ChatClientAttributes.OUTPUT_FORMAT.getKey()) instanceof String format) {
return format;
}
return null;
}
/**
* @deprecated not used anymore. The format instructions are already included in the
* ChatModelObservationContext.
*/
@Deprecated
public void setFormat(@Nullable String format) {
this.format = format;
this.request.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
}
public static final class Builder {
private DefaultChatClientRequestSpec request;
private ChatClientRequest chatClientRequest;
private String format;
@@ -87,23 +98,41 @@ public class ChatClientObservationContext extends Observation.Context {
private Builder() {
}
public Builder withRequest(DefaultChatClientRequestSpec request) {
this.request = request;
public Builder request(ChatClientRequest chatClientRequest) {
this.chatClientRequest = chatClientRequest;
return this;
}
@Deprecated // use request(ChatClientRequest chatClientRequest)
public Builder withRequest(ChatClientRequest chatClientRequest) {
return request(chatClientRequest);
}
/**
* @deprecated not used anymore. The format instructions are already included in
* the ChatModelObservationContext.
*/
@Deprecated
public Builder withFormat(String format) {
this.format = format;
return this;
}
public Builder withStream(boolean isStream) {
public Builder stream(boolean isStream) {
this.isStream = isStream;
return this;
}
@Deprecated // use stream(boolean isStream)
public Builder withStream(boolean isStream) {
return stream(isStream);
}
public ChatClientObservationContext build() {
return new ChatClientObservationContext(this.request, this.format, this.isStream);
if (StringUtils.hasText(format)) {
this.chatClientRequest.context().put(ChatClientAttributes.OUTPUT_FORMAT.getKey(), format);
}
return new ChatClientObservationContext(this.chatClientRequest, this.isStream);
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,15 +19,20 @@ package org.springframework.ai.chat.client.observation;
import io.micrometer.common.KeyValue;
import io.micrometer.common.KeyValues;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.observation.conventions.SpringAiKind;
import org.springframework.ai.observation.tracing.TracingHelper;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
/**
* Default conventions to populate observations for chat client workflows.
*
@@ -88,53 +93,71 @@ public class DefaultChatClientObservationConvention implements ChatClientObserva
public KeyValues getHighCardinalityKeyValues(ChatClientObservationContext context) {
var keyValues = KeyValues.empty();
keyValues = chatClientAdvisorNames(keyValues, context);
// TODO: rename attribute? any sensitive data here?
keyValues = chatClientAdvisorParams(keyValues, context);
keyValues = toolFunctionNames(keyValues, context);
keyValues = toolFunctionCallbacks(keyValues, context);
// TODO: remove this? Already included in chat model observation
keyValues = toolNames(keyValues, context);
// TODO: remove this? Already included in chat model observation
keyValues = toolCallbacks(keyValues, context);
return keyValues;
}
@SuppressWarnings("unchecked")
protected KeyValues chatClientAdvisorNames(KeyValues keyValues, ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getAdvisors())) {
if (!(context.getRequest().context().get(ChatClientAttributes.ADVISORS.getKey()) instanceof List<?> advisors)) {
return keyValues;
}
var advisorNames = context.getRequest().getAdvisors().stream().map(Advisor::getName).toList();
var advisorNames = ((List<Advisor>) advisors).stream().map(Advisor::getName).toList();
return keyValues.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(),
TracingHelper.concatenateStrings(advisorNames));
}
protected KeyValues chatClientAdvisorParams(KeyValues keyValues, ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getAdvisorParams())) {
if (CollectionUtils.isEmpty(context.getRequest().context())) {
return keyValues;
}
var advisorParams = context.getRequest().getAdvisorParams();
var chatClientContext = context.getRequest().context();
Arrays.stream(ChatClientAttributes.values()).forEach(attribute -> chatClientContext.remove(attribute.getKey()));
return keyValues.and(
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(),
TracingHelper.concatenateMaps(advisorParams));
TracingHelper.concatenateMaps(chatClientContext));
}
protected KeyValues toolFunctionNames(KeyValues keyValues, ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getFunctionNames())) {
protected KeyValues toolNames(KeyValues keyValues, ChatClientObservationContext context) {
if (context.getRequest().prompt().getOptions() == null) {
return keyValues;
}
var functionNames = context.getRequest().getFunctionNames();
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
return keyValues;
}
var toolNames = options.getToolNames();
if (CollectionUtils.isEmpty(toolNames)) {
return keyValues;
}
return keyValues.and(
ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(),
TracingHelper.concatenateStrings(functionNames));
TracingHelper.concatenateStrings(toolNames.stream().sorted().toList()));
}
protected KeyValues toolFunctionCallbacks(KeyValues keyValues, ChatClientObservationContext context) {
if (CollectionUtils.isEmpty(context.getRequest().getFunctionCallbacks())) {
protected KeyValues toolCallbacks(KeyValues keyValues, ChatClientObservationContext context) {
if (context.getRequest().prompt().getOptions() == null) {
return keyValues;
}
var functionCallbacks = context.getRequest()
.getFunctionCallbacks()
.stream()
.map(FunctionCallback::getName)
.toList();
if (!(context.getRequest().prompt().getOptions() instanceof ToolCallingChatOptions options)) {
return keyValues;
}
var toolCallbacks = options.getToolCallbacks();
if (CollectionUtils.isEmpty(toolCallbacks)) {
return keyValues;
}
var toolCallbackNames = toolCallbacks.stream().map(FunctionCallback::getName).sorted().toList();
return keyValues
.and(ChatClientObservationDocumentation.HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS
.asString(), TracingHelper.concatenateStrings(functionCallbacks));
.asString(), TracingHelper.concatenateStrings(toolCallbackNames));
}
}

View File

@@ -0,0 +1,63 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.prompt.Prompt;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link ChatClientRequest}.
*
* @author Thomas Vitale
*/
class ChatClientRequestTests {
@Test
void whenPromptIsNullThenThrow() {
assertThatThrownBy(() -> new ChatClientRequest(null, Map.of())).isInstanceOf(IllegalArgumentException.class)
.hasMessage("prompt cannot be null");
assertThatThrownBy(() -> ChatClientRequest.builder().prompt(null).context(Map.of()).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("prompt cannot be null");
}
@Test
void whenContextIsNullThenThrow() {
assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
assertThatThrownBy(() -> ChatClientRequest.builder().prompt(new Prompt()).context(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}
@Test
void whenContextHasNullKeysThenThrow() {
Map<String, Object> context = new HashMap<>();
context.put(null, "something");
assertThatThrownBy(() -> new ChatClientRequest(new Prompt(), context))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context keys cannot be null");
}
}

View File

@@ -0,0 +1,51 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import org.junit.jupiter.api.Test;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link ChatClientResponse}.
*
* @author Thomas Vitale
*/
class ChatClientResponseTests {
@Test
void whenContextIsNullThenThrow() {
assertThatThrownBy(() -> new ChatClientResponse(null, null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
assertThatThrownBy(() -> ChatClientResponse.builder().chatResponse(null).context(null).build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("context cannot be null");
}
@Test
void whenContextHasNullKeysThenThrow() {
Map<String, Object> context = new HashMap<>();
context.put(null, "something");
assertThatThrownBy(() -> new ChatClientResponse(null, context)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("context keys cannot be null");
}
}

View File

@@ -29,6 +29,9 @@ import java.util.function.Consumer;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.tool.ToolCallback;
import reactor.core.publisher.Flux;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
@@ -46,7 +49,6 @@ import org.springframework.ai.content.Media;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.ClassPathResource;
@@ -598,17 +600,67 @@ class DefaultChatClientTests {
void buildCallResponseSpec() {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt();
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
.prompt("question");
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThat(spec).isNotNull();
}
@Test
void buildCallResponseSpecWithNullRequest() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null))
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(null, mock(BaseAdvisorChain.class),
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("request cannot be null");
.hasMessage("chatClientRequest cannot be null");
}
@Test
void buildCallResponseSpecWithNullAdvisorChain() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class), null,
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("advisorChain cannot be null");
}
@Test
void buildCallResponseSpecWithNullObservationRegistry() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class),
mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationRegistry cannot be null");
}
@Test
void buildCallResponseSpecWithNullObservationConvention() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultCallResponseSpec(mock(ChatClientRequest.class),
mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationConvention cannot be null");
}
@Test
void whenSimplePromptThenChatClientResponse() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.call(promptCaptor.capture()))
.willReturn(new ChatResponse(List.of(new Generation(new AssistantMessage("response")))));
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatClientResponse chatClientResponse = spec.chatClientResponse();
assertThat(chatClientResponse).isNotNull();
ChatResponse chatResponse = chatClientResponse.chatResponse();
assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response");
Prompt actualPrompt = promptCaptor.getValue();
assertThat(actualPrompt.getInstructions()).hasSize(1);
assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question");
}
@Test
@@ -621,8 +673,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatResponse chatResponse = spec.chatResponse();
assertThat(chatResponse).isNotNull();
@@ -644,8 +696,8 @@ class DefaultChatClientTests {
Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question"));
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt(prompt);
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatResponse chatResponse = spec.chatResponse();
assertThat(chatResponse).isNotNull();
@@ -669,8 +721,8 @@ class DefaultChatClientTests {
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt(prompt)
.user("another question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatResponse chatResponse = spec.chatResponse();
assertThat(chatResponse).isNotNull();
@@ -696,8 +748,8 @@ class DefaultChatClientTests {
.prompt()
.user("another question")
.messages(messages);
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatResponse chatResponse = spec.chatResponse();
assertThat(chatResponse).isNotNull();
@@ -719,8 +771,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ChatResponse chatResponse = spec.chatResponse();
assertThat(chatResponse).isNull();
@@ -736,8 +788,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
String content = spec.content();
assertThat(content).isNull();
@@ -748,10 +800,10 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference) null))
assertThatThrownBy(() -> spec.responseEntity((ParameterizedTypeReference<?>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("type cannot be null");
}
@@ -766,8 +818,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
.responseEntity(new ParameterizedTypeReference<>() {
@@ -793,8 +845,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, List<Person>> responseEntity = spec
.responseEntity(new ParameterizedTypeReference<>() {
@@ -808,10 +860,10 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter) null))
assertThatThrownBy(() -> spec.responseEntity((StructuredOutputConverter<?>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("structuredOutputConverter cannot be null");
}
@@ -826,8 +878,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
.responseEntity(new ListOutputConverter(new DefaultConversionService()));
@@ -847,8 +899,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, List<String>> responseEntity = spec
.responseEntity(new ListOutputConverter(new DefaultConversionService()));
@@ -861,8 +913,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.responseEntity((Class) null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("type cannot be null");
@@ -878,8 +930,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, String> responseEntity = spec.responseEntity(String.class);
assertThat(responseEntity.response()).isNotNull();
@@ -898,8 +950,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
ResponseEntity<ChatResponse, Person> responseEntity = spec.responseEntity(Person.class);
assertThat(responseEntity.response()).isNotNull();
@@ -912,10 +964,10 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference) null))
assertThatThrownBy(() -> spec.entity((ParameterizedTypeReference<?>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("type cannot be null");
}
@@ -930,8 +982,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
List<String> entity = spec.entity(new ParameterizedTypeReference<>() {
});
@@ -954,8 +1006,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
List<Person> entity = spec.entity(new ParameterizedTypeReference<>() {
});
@@ -967,10 +1019,10 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.entity((StructuredOutputConverter) null))
assertThatThrownBy(() -> spec.entity((StructuredOutputConverter<?>) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("structuredOutputConverter cannot be null");
}
@@ -980,8 +1032,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
List<String> entity = spec.entity(new ListOutputConverter(new DefaultConversionService()));
assertThat(entity).isNull();
@@ -999,8 +1051,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
List<String> entity = spec.entity(new ListOutputConverter(new DefaultConversionService()));
assertThat(entity).hasSize(3);
@@ -1011,10 +1063,10 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
assertThatThrownBy(() -> spec.entity((Class) null)).isInstanceOf(IllegalArgumentException.class)
assertThatThrownBy(() -> spec.entity((Class<?>) null)).isInstanceOf(IllegalArgumentException.class)
.hasMessage("type cannot be null");
}
@@ -1028,8 +1080,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
String entity = spec.entity(String.class);
assertThat(entity).isNull();
@@ -1047,8 +1099,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultCallResponseSpec spec = new DefaultChatClient.DefaultCallResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultCallResponseSpec spec = (DefaultChatClient.DefaultCallResponseSpec) chatClientRequestSpec
.call();
Person entity = spec.entity(Person.class);
assertThat(entity).isNotNull();
@@ -1061,17 +1113,67 @@ class DefaultChatClientTests {
void buildStreamResponseSpec() {
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt();
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
.prompt("question");
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
assertThat(spec).isNotNull();
}
@Test
void buildStreamResponseSpecWithNullRequest() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null))
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(null, mock(BaseAdvisorChain.class),
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("request cannot be null");
.hasMessage("chatClientRequest cannot be null");
}
@Test
void buildStreamResponseSpecWithNullAdvisorChain() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class), null,
mock(ObservationRegistry.class), mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("advisorChain cannot be null");
}
@Test
void buildStreamResponseSpecWithNullObservationRegistry() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class),
mock(BaseAdvisorChain.class), null, mock(ChatClientObservationConvention.class)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationRegistry cannot be null");
}
@Test
void buildStreamResponseSpecWithNullObservationConvention() {
assertThatThrownBy(() -> new DefaultChatClient.DefaultStreamResponseSpec(mock(ChatClientRequest.class),
mock(BaseAdvisorChain.class), mock(ObservationRegistry.class), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("observationConvention cannot be null");
}
@Test
void whenSimplePromptThenFluxChatClientResponse() {
ChatModel chatModel = mock(ChatModel.class);
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
given(chatModel.stream(promptCaptor.capture()))
.willReturn(Flux.just(new ChatResponse(List.of(new Generation(new AssistantMessage("response"))))));
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
ChatClientResponse chatClientResponse = spec.chatClientResponse().blockLast();
assertThat(chatClientResponse).isNotNull();
ChatResponse chatResponse = chatClientResponse.chatResponse();
assertThat(chatResponse).isNotNull();
assertThat(chatResponse.getResult().getOutput().getText()).isEqualTo("response");
Prompt actualPrompt = promptCaptor.getValue();
assertThat(actualPrompt.getInstructions()).hasSize(1);
assertThat(actualPrompt.getInstructions().get(0).getText()).isEqualTo("my question");
}
@Test
@@ -1084,8 +1186,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
ChatResponse chatResponse = spec.chatResponse().blockLast();
assertThat(chatResponse).isNotNull();
@@ -1107,8 +1209,8 @@ class DefaultChatClientTests {
Prompt prompt = new Prompt(new SystemMessage("instructions"), new UserMessage("my question"));
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt(prompt);
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
ChatResponse chatResponse = spec.chatResponse().blockLast();
assertThat(chatResponse).isNotNull();
@@ -1132,8 +1234,8 @@ class DefaultChatClientTests {
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt(prompt)
.user("another question");
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
ChatResponse chatResponse = spec.chatResponse().blockLast();
assertThat(chatResponse).isNotNull();
@@ -1159,8 +1261,9 @@ class DefaultChatClientTests {
.prompt()
.user("another question")
.messages(messages);
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
ChatResponse chatResponse = spec.chatResponse().blockLast();
assertThat(chatResponse).isNotNull();
@@ -1183,8 +1286,8 @@ class DefaultChatClientTests {
ChatClient chatClient = new DefaultChatClientBuilder(chatModel).build();
DefaultChatClient.DefaultChatClientRequestSpec chatClientRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) chatClient
.prompt("my question");
DefaultChatClient.DefaultStreamResponseSpec spec = new DefaultChatClient.DefaultStreamResponseSpec(
chatClientRequestSpec);
DefaultChatClient.DefaultStreamResponseSpec spec = (DefaultChatClient.DefaultStreamResponseSpec) chatClientRequestSpec
.stream();
String content = spec.content().blockLast();
assertThat(content).isNull();

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,7 +21,16 @@ import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -147,4 +156,58 @@ class AdvisedRequestTests {
.hasMessage("toolContext cannot be null");
}
@Test
void whenConvertToAndFromChatClientRequest() {
ChatModel chatModel = mock(ChatModel.class);
ChatOptions chatOptions = ToolCallingChatOptions.builder().build();
List<Message> messages = List.of(mock(UserMessage.class));
SystemMessage systemMessage = new SystemMessage("Instructions {key}");
UserMessage userMessage = new UserMessage("Question {key}", mock(Media.class));
Map<String, Object> systemParams = Map.of("key", "value");
Map<String, Object> userParams = Map.of("key", "value");
List<String> toolNames = List.of("tool1", "tool2");
ToolCallback toolCallback = mock(ToolCallback.class);
Map<String, Object> toolContext = Map.of("key", "value");
List<Advisor> advisors = List.of(mock(Advisor.class));
Map<String, Object> advisorContext = Map.of("key", "value");
AdvisedRequest advisedRequest = AdvisedRequest.builder()
.chatModel(chatModel)
.chatOptions(chatOptions)
.messages(messages)
.systemText(systemMessage.getText())
.systemParams(systemParams)
.userText(userMessage.getText())
.userParams(userParams)
.media(userMessage.getMedia())
.functionNames(toolNames)
.functionCallbacks(List.of(toolCallback))
.toolContext(toolContext)
.advisors(advisors)
.adviseContext(advisorContext)
.build();
ChatClientRequest chatClientRequest = advisedRequest.toChatClientRequest();
assertThat(chatClientRequest.context().get(ChatClientAttributes.CHAT_MODEL.getKey())).isEqualTo(chatModel);
assertThat(chatClientRequest.prompt().getOptions()).isEqualTo(chatOptions);
assertThat(chatClientRequest.prompt().getInstructions()).hasSize(3);
assertThat(chatClientRequest.prompt().getInstructions().get(0)).isEqualTo(messages.get(0));
assertThat(chatClientRequest.prompt().getInstructions().get(1).getText()).isEqualTo("Instructions value");
assertThat(chatClientRequest.prompt().getInstructions().get(2).getText()).isEqualTo("Question value");
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolNames())
.containsAll(toolNames);
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolCallbacks())
.contains(toolCallback);
assertThat(((ToolCallingChatOptions) chatClientRequest.prompt().getOptions()).getToolContext())
.containsAllEntriesOf(toolContext);
assertThat((List<Advisor>) chatClientRequest.context().get(ChatClientAttributes.ADVISORS.getKey()))
.containsAll(advisors);
assertThat(chatClientRequest.context()).containsAllEntriesOf(advisorContext);
AdvisedRequest convertedAdvisedRequest = AdvisedRequest.from(chatClientRequest);
assertThat(convertedAdvisedRequest.toPrompt()).isEqualTo(chatClientRequest.prompt());
assertThat(convertedAdvisedRequest.adviseContext()).containsAllEntriesOf(chatClientRequest.context());
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import java.util.Map;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.model.ChatResponse;
import static org.assertj.core.api.Assertions.assertThat;
@@ -67,7 +68,8 @@ class AdvisedResponseTests {
@Test
void whenBuildFromNullAdvisedResponseThenThrows() {
assertThatThrownBy(() -> AdvisedResponse.from(null)).isInstanceOf(IllegalArgumentException.class)
assertThatThrownBy(() -> AdvisedResponse.from((AdvisedResponse) null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("advisedResponse cannot be null");
}
@@ -85,4 +87,16 @@ class AdvisedResponseTests {
.hasMessage("contextTransform cannot be null");
}
@Test
void whenConvertToAndFromChatClientResponse() {
ChatResponse chatResponse = mock(ChatResponse.class);
Map<String, Object> context = Map.of("key", "value");
AdvisedResponse advisedResponse = new AdvisedResponse(chatResponse, context);
ChatClientResponse chatClientResponse = advisedResponse.toChatClientResponse();
AdvisedResponse newAdvisedResponse = AdvisedResponse.from(chatClientResponse);
assertThat(newAdvisedResponse).isEqualTo(advisedResponse);
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,9 +17,13 @@
package org.springframework.ai.chat.client.advisor.observation;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.prompt.Prompt;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;
/**
* Unit tests for {@link AdvisorObservationContext}.
@@ -32,8 +36,7 @@ class AdvisorObservationContextTests {
@Test
void whenMandatoryOptionsThenReturn() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("MyName")
.advisorType(AdvisorObservationContext.Type.BEFORE)
.advisorName("AdvisorName")
.build();
assertThat(observationContext).isNotNull();
@@ -41,17 +44,38 @@ class AdvisorObservationContextTests {
@Test
void missingAdvisorName() {
assertThatThrownBy(
() -> AdvisorObservationContext.builder().advisorType(AdvisorObservationContext.Type.BEFORE).build())
assertThatThrownBy(() -> AdvisorObservationContext.builder().build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("advisorName must not be null or empty");
.hasMessageContaining("advisorName cannot be null or empty");
}
@Test
void missingAdvisorType() {
assertThatThrownBy(() -> AdvisorObservationContext.builder().advisorName("MyName").build())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("advisorType must not be null");
void whenBuilderWithAdvisedRequestThenReturn() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("AdvisorName")
.advisedRequest(mock(AdvisedRequest.class))
.build();
assertThat(observationContext).isNotNull();
}
@Test
void whenBuilderWithChatClientRequestThenReturn() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("AdvisorName")
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build())
.build();
assertThat(observationContext).isNotNull();
}
@Test
void missingBuilderWithBothRequestsThenThrow() {
assertThatThrownBy(() -> AdvisorObservationContext.builder()
.advisedRequest(mock(AdvisedRequest.class))
.chatClientRequest(ChatClientRequest.builder().prompt(new Prompt()).build())
.build()).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("ChatClientRequest and AdvisedRequest cannot be set at the same time");
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -47,7 +47,6 @@ class DefaultAdvisorObservationConventionTests {
void contextualName() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("MyName")
.advisorType(AdvisorObservationContext.Type.AROUND)
.build();
assertThat(this.observationConvention.getContextualName(observationContext)).isEqualTo("my_name");
}
@@ -56,7 +55,6 @@ class DefaultAdvisorObservationConventionTests {
void supportsAdvisorObservationContext() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("MyName")
.advisorType(AdvisorObservationContext.Type.AROUND)
.build();
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
assertThat(this.observationConvention.supportsContext(new Observation.Context())).isFalse();
@@ -66,11 +64,8 @@ class DefaultAdvisorObservationConventionTests {
void shouldHaveLowCardinalityKeyValuesWhenDefined() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("MyName")
.advisorType(AdvisorObservationContext.Type.AROUND)
.build();
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
KeyValue.of(LowCardinalityKeyNames.ADVISOR_TYPE.asString(),
AdvisorObservationContext.Type.AROUND.name()),
KeyValue.of(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), AiOperationType.FRAMEWORK.value()),
KeyValue.of(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.SPRING_AI.value()),
KeyValue.of(LowCardinalityKeyNames.ADVISOR_NAME.asString(), "MyName"),
@@ -81,7 +76,6 @@ class DefaultAdvisorObservationConventionTests {
void shouldHaveKeyValuesWhenDefinedAndResponse() {
AdvisorObservationContext observationContext = AdvisorObservationContext.builder()
.advisorName("MyName")
.advisorType(AdvisorObservationContext.Type.AROUND)
.order(678)
.build();

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,20 +16,22 @@
package org.springframework.ai.chat.client.observation;
import java.util.List;
import java.util.Map;
import io.micrometer.common.KeyValue;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import static org.assertj.core.api.Assertions.assertThat;
@@ -57,14 +59,9 @@ class ChatClientInputContentObservationFilterTests {
@Test
void whenEmptyInputContentThenReturnOriginalContext() {
ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
ChatClientObservationConvention customObservationConvention = null;
var request = ChatClientRequest.builder().prompt(new Prompt()).build();
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), observationRegistry, customObservationConvention,
Map.of());
var expectedContext = ChatClientObservationContext.builder().withRequest(request).build();
var expectedContext = ChatClientObservationContext.builder().request(request).build();
var actualContext = this.observationFilter.map(expectedContext);
@@ -73,14 +70,13 @@ class ChatClientInputContentObservationFilterTests {
@Test
void whenWithTextThenAugmentContext() {
ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
ChatClientObservationConvention customObservationConvention = null;
var request = ChatClientRequest.builder()
.prompt(new Prompt(new SystemMessage("sample system text"), new UserMessage("sample user text")))
.context(ChatClientAttributes.USER_PARAMS.getKey(), Map.of("up1", "upv1"))
.context(ChatClientAttributes.SYSTEM_PARAMS.getKey(), Map.of("sp1", "sp1v"))
.build();
var request = new DefaultChatClientRequestSpec(this.chatModel, "sample user text", Map.of("up1", "upv1"),
"sample system text", Map.of("sp1", "sp1v"), List.of(), List.of(), List.of(), List.of(), null,
List.of(), Map.of(), observationRegistry, customObservationConvention, Map.of());
var originalContext = ChatClientObservationContext.builder().withRequest(request).build();
var originalContext = ChatClientObservationContext.builder().request(request).build();
var augmentedContext = this.observationFilter.map(originalContext);

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,17 +16,14 @@
package org.springframework.ai.chat.client.observation;
import java.util.List;
import java.util.Map;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import static org.assertj.core.api.Assertions.assertThat;
@@ -44,11 +41,10 @@ class ChatClientObservationContextTests {
@Test
void whenMandatoryRequestOptionsThenReturn() {
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
var observationContext = ChatClientObservationContext.builder().withRequest(request).withStream(true).build();
var observationContext = ChatClientObservationContext.builder()
.request(ChatClientRequest.builder().prompt(new Prompt()).build())
.stream(true)
.build();
assertThat(observationContext).isNotNull();
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,18 +17,17 @@
package org.springframework.ai.chat.client.observation;
import java.util.List;
import java.util.Map;
import io.micrometer.common.KeyValue;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.ChatClientAttributes;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
@@ -36,7 +35,9 @@ import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.observation.conventions.SpringAiKind;
@@ -56,7 +57,7 @@ class DefaultChatClientObservationConventionTests {
@Mock
ChatModel chatModel;
DefaultChatClientRequestSpec request;
ChatClientRequest request;
static CallAroundAdvisor dummyAdvisor(String name) {
return new CallAroundAdvisor() {
@@ -109,8 +110,7 @@ class DefaultChatClientObservationConventionTests {
@BeforeEach
public void beforeEach() {
this.request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(), List.of(),
List.of(), List.of(), List.of(), null, List.of(), Map.of(), ObservationRegistry.NOOP, null, Map.of());
this.request = ChatClientRequest.builder().prompt(new Prompt()).build();
}
@Test
@@ -121,8 +121,8 @@ class DefaultChatClientObservationConventionTests {
@Test
void shouldHaveContextualName() {
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(this.request)
.withStream(true)
.request(this.request)
.stream(true)
.build();
assertThat(this.observationConvention.getContextualName(observationContext))
@@ -132,8 +132,8 @@ class DefaultChatClientObservationConventionTests {
@Test
void supportsOnlyChatClientObservationContext() {
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(this.request)
.withStream(true)
.request(this.request)
.stream(true)
.build();
assertThat(this.observationConvention.supportsContext(observationContext)).isTrue();
@@ -143,8 +143,8 @@ class DefaultChatClientObservationConventionTests {
@Test
void shouldHaveRequiredKeyValues() {
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(this.request)
.withStream(true)
.request(this.request)
.stream(true)
.build();
assertThat(this.observationConvention.getLowCardinalityKeyValues(observationContext)).contains(
@@ -154,27 +154,31 @@ class DefaultChatClientObservationConventionTests {
@Test
void shouldHaveOptionalKeyValues() {
var request = new DefaultChatClientRequestSpec(this.chatModel, "", Map.of(), "", Map.of(),
List.of(dummyFunction("functionCallback1"), dummyFunction("functionCallback2")), List.of(),
List.of("function1", "function2"), List.of(), null,
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")), Map.of("advParam1", "advisorParam1Value"),
ObservationRegistry.NOOP, null, Map.of());
var request = ChatClientRequest.builder()
.prompt(new Prompt("",
ToolCallingChatOptions.builder()
.toolNames("tool1", "tool2")
.toolCallbacks(dummyFunction("toolCallback1"), dummyFunction("toolCallback2"))
.build()))
.context("advParam1", "advisorParam1Value")
.context(ChatClientAttributes.ADVISORS.getKey(),
List.of(dummyAdvisor("advisor1"), dummyAdvisor("advisor2")))
.build();
ChatClientObservationContext observationContext = ChatClientObservationContext.builder()
.withRequest(request)
.request(request)
.withFormat("json")
.withStream(true)
.stream(true)
.build();
assertThat(this.observationConvention.getHighCardinalityKeyValues(observationContext)).contains(
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(),
"[\"advisor1\", \"advisor2\", \"CallAroundAdvisor\", \"StreamAroundAdvisor\"]"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISORS.asString(), "[\"advisor1\", \"advisor2\"]"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_ADVISOR_PARAMS.asString(),
"[\"advParam1\":\"advisorParam1Value\"]"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_NAMES.asString(),
"[\"function1\", \"function2\"]"),
"[\"tool1\", \"tool2\"]"),
KeyValue.of(HighCardinalityKeyNames.CHAT_CLIENT_TOOL_FUNCTION_CALLBACKS.asString(),
"[\"functionCallback1\", \"functionCallback2\"]"));
"[\"toolCallback1\", \"toolCallback2\"]"));
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -148,4 +148,21 @@ class PromptTests {
assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions());
}
@Test
public void mutatePrompt() {
String template = "Hello, {name}! Your age is {age}.";
Map<String, Object> model = new HashMap<>();
model.put("name", "Alice");
model.put("age", 30);
PromptTemplate promptTemplate = new PromptTemplate(template, model);
ChatOptions chatOptions = ChatOptions.builder().temperature(0.5).maxTokens(100).build();
Prompt prompt = promptTemplate.create(model, chatOptions);
Prompt copiedPrompt = prompt.mutate().build();
assertThat(prompt).isNotSameAs(copiedPrompt);
assertThat(prompt.getOptions()).isNotSameAs(copiedPrompt.getOptions());
assertThat(prompt.getInstructions()).isNotSameAs(copiedPrompt.getInstructions());
}
}

View File

@@ -24,6 +24,7 @@ import java.util.Map;
import java.util.Objects;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
@@ -49,6 +50,7 @@ public abstract class AbstractMessage implements Message {
/**
* The content of the message.
*/
@Nullable
protected final String textContent;
/**
@@ -63,11 +65,12 @@ public abstract class AbstractMessage implements Message {
* @param textContent the text content
* @param metadata the metadata
*/
protected AbstractMessage(MessageType messageType, String textContent, Map<String, Object> metadata) {
protected AbstractMessage(MessageType messageType, @Nullable String textContent, Map<String, Object> metadata) {
Assert.notNull(messageType, "Message type must not be null");
if (messageType == MessageType.SYSTEM || messageType == MessageType.USER) {
Assert.notNull(textContent, "Content must not be null for SYSTEM or USER messages");
}
Assert.notNull(metadata, "Metadata must not be null");
this.messageType = messageType;
this.textContent = textContent;
this.metadata = new HashMap<>(metadata);
@@ -81,7 +84,9 @@ public abstract class AbstractMessage implements Message {
* @param metadata the metadata
*/
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> metadata) {
Assert.notNull(messageType, "Message type must not be null");
Assert.notNull(resource, "Resource must not be null");
Assert.notNull(metadata, "Metadata must not be null");
try (InputStream inputStream = resource.getInputStream()) {
this.textContent = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
}
@@ -98,6 +103,7 @@ public abstract class AbstractMessage implements Message {
* @return the content of the message
*/
@Override
@Nullable
public String getText() {
return this.textContent;
}

View File

@@ -0,0 +1,52 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.messages;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StreamUtils;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
/**
* Utility class for managing messages.
*
* @author Thomas Vitale
*/
final class MessageUtils {
private MessageUtils() {
}
static String readResource(Resource resource) {
return readResource(resource, Charset.defaultCharset());
}
static String readResource(Resource resource, Charset charset) {
Assert.notNull(resource, "resource cannot be null");
Assert.notNull(charset, "charset cannot be null");
try (InputStream inputStream = resource.getInputStream()) {
return StreamUtils.copyToString(inputStream, charset);
}
catch (IOException ex) {
throw new RuntimeException("Failed to read resource", ex);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,10 +16,14 @@
package org.springframework.ai.chat.messages;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.springframework.core.io.Resource;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
/**
* A message of the type 'system' passed as input. The system message gives high level
@@ -31,14 +35,19 @@ import org.springframework.core.io.Resource;
public class SystemMessage extends AbstractMessage {
public SystemMessage(String textContent) {
super(MessageType.SYSTEM, textContent, Map.of());
this(textContent, Map.of());
}
public SystemMessage(Resource resource) {
super(MessageType.SYSTEM, resource, Map.of());
this(MessageUtils.readResource(resource), Map.of());
}
private SystemMessage(String textContent, Map<String, Object> metadata) {
super(MessageType.SYSTEM, textContent, metadata);
}
@Override
@NonNull
public String getText() {
return this.textContent;
}
@@ -68,4 +77,53 @@ public class SystemMessage extends AbstractMessage {
+ ", metadata=" + this.metadata + '}';
}
public SystemMessage copy() {
return new SystemMessage(getText(), Map.copyOf(this.metadata));
}
public Builder mutate() {
return new Builder().text(this.textContent).metadata(this.metadata);
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
@Nullable
private String textContent;
@Nullable
private Resource resource;
private Map<String, Object> metadata = new HashMap<>();
public Builder text(String textContent) {
this.textContent = textContent;
return this;
}
public Builder text(Resource resource) {
this.resource = resource;
return this;
}
public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}
public SystemMessage build() {
if (StringUtils.hasText(textContent) && resource != null) {
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
}
else if (resource != null) {
this.textContent = MessageUtils.readResource(resource);
}
return new SystemMessage(this.textContent, this.metadata);
}
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,13 +19,17 @@ package org.springframework.ai.chat.messages;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
import org.springframework.core.io.Resource;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* A message of the type 'user' passed as input Messages with the user role are from the
@@ -37,26 +41,45 @@ public class UserMessage extends AbstractMessage implements MediaContent {
protected final List<Media> media;
public UserMessage(String textContent) {
this(MessageType.USER, textContent, new ArrayList<>(), Map.of());
this(textContent, new ArrayList<>(), Map.of());
}
public UserMessage(Resource resource) {
super(MessageType.USER, resource, Map.of());
this.media = new ArrayList<>();
this(MessageUtils.readResource(resource));
}
/**
* @deprecated use {@link #builder()} instead.
*/
@Deprecated
public UserMessage(String textContent, List<Media> media) {
this(MessageType.USER, textContent, media, Map.of());
}
/**
* @deprecated use {@link #builder()} instead.
*/
@Deprecated
public UserMessage(String textContent, Media... media) {
this(textContent, Arrays.asList(media));
}
public UserMessage(String textContent, Collection<Media> mediaList, Map<String, Object> metadata) {
this(MessageType.USER, textContent, mediaList, metadata);
/**
* @deprecated use {@link #builder()} instead. Will be made private in the next
* release.
*/
@Deprecated
public UserMessage(String textContent, Collection<Media> media, Map<String, Object> metadata) {
super(MessageType.USER, textContent, metadata);
Assert.notNull(media, "media cannot be null");
Assert.noNullElements(media, "media cannot have null elements");
this.media = new ArrayList<>(media);
}
/**
* @deprecated use {@link #builder()} instead.
*/
@Deprecated
public UserMessage(MessageType messageType, String textContent, Collection<Media> media,
Map<String, Object> metadata) {
super(messageType, textContent, metadata);
@@ -71,13 +94,77 @@ public class UserMessage extends AbstractMessage implements MediaContent {
}
@Override
public List<Media> getMedia() {
return this.media;
}
@Override
@NonNull
public String getText() {
return this.textContent;
}
@Override
public List<Media> getMedia() {
return this.media;
}
public UserMessage copy() {
return new UserMessage(getText(), List.copyOf(getMedia()), Map.copyOf(getMetadata()));
}
public Builder mutate() {
return new Builder().text(getText()).media(List.copyOf(getMedia())).metadata(Map.copyOf(getMetadata()));
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
@Nullable
private String textContent;
@Nullable
private Resource resource;
private List<Media> media = new ArrayList<>();
private Map<String, Object> metadata = new HashMap<>();
public Builder text(String textContent) {
this.textContent = textContent;
return this;
}
public Builder text(Resource resource) {
this.resource = resource;
return this;
}
public Builder media(List<Media> media) {
this.media = media;
return this;
}
public Builder media(@Nullable Media... media) {
if (media != null) {
this.media = Arrays.asList(media);
}
return this;
}
public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}
public UserMessage build() {
if (StringUtils.hasText(textContent) && resource != null) {
throw new IllegalArgumentException("textContent and resource cannot be set at the same time");
}
else if (resource != null) {
this.textContent = MessageUtils.readResource(resource);
}
return new UserMessage(this.textContent, this.media, this.metadata);
}
}
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
@NonNullApi
@NonNullFields
package org.springframework.ai.chat.messages;
import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;

View File

@@ -30,6 +30,8 @@ import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.model.ModelRequest;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* The Prompt class represents a prompt used in AI model requests. A prompt consists of
@@ -62,15 +64,15 @@ public class Prompt implements ModelRequest<List<Message>> {
this(Arrays.asList(messages), null);
}
public Prompt(String contents, ChatOptions chatOptions) {
public Prompt(String contents, @Nullable ChatOptions chatOptions) {
this(new UserMessage(contents), chatOptions);
}
public Prompt(Message message, ChatOptions chatOptions) {
public Prompt(Message message, @Nullable ChatOptions chatOptions) {
this(Collections.singletonList(message), chatOptions);
}
public Prompt(List<Message> messages, ChatOptions chatOptions) {
public Prompt(List<Message> messages, @Nullable ChatOptions chatOptions) {
this.messages = messages;
this.chatOptions = chatOptions;
}
@@ -123,10 +125,17 @@ public class Prompt implements ModelRequest<List<Message>> {
List<Message> messagesCopy = new ArrayList<>();
this.messages.forEach(message -> {
if (message instanceof UserMessage userMessage) {
messagesCopy.add(new UserMessage(userMessage.getText(), userMessage.getMedia(), message.getMetadata()));
messagesCopy.add(UserMessage.builder()
.text(userMessage.getText())
.media(userMessage.getMedia())
.metadata(message.getMetadata())
.build());
}
else if (message instanceof SystemMessage systemMessage) {
messagesCopy.add(new SystemMessage(systemMessage.getText()));
messagesCopy.add(SystemMessage.builder()
.text(systemMessage.getText())
.metadata(systemMessage.getMetadata())
.build());
}
else if (message instanceof AssistantMessage assistantMessage) {
messagesCopy.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(),
@@ -144,4 +153,69 @@ public class Prompt implements ModelRequest<List<Message>> {
return messagesCopy;
}
public Builder mutate() {
Builder builder = new Builder().messages(instructionsCopy());
if (this.chatOptions != null) {
builder.chatOptions(this.chatOptions.copy());
}
return builder;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
@Nullable
private String content;
@Nullable
private List<Message> messages = new ArrayList<>();
@Nullable
private ChatOptions chatOptions;
public Builder content(@Nullable String content) {
this.content = content;
return this;
}
public Builder messages(Message... messages) {
if (messages != null) {
this.messages = Arrays.asList(messages);
}
return this;
}
public Builder messages(List<Message> messages) {
this.messages = messages;
return this;
}
public Builder addMessage(Message message) {
if (this.messages == null) {
this.messages = new ArrayList<>();
}
this.messages.add(message);
return this;
}
public Builder chatOptions(ChatOptions chatOptions) {
this.chatOptions = chatOptions;
return this;
}
public Prompt build() {
if (StringUtils.hasText(this.content) && !CollectionUtils.isEmpty(this.messages)) {
throw new IllegalArgumentException("content and messages cannot be set at the same time");
}
else if (StringUtils.hasText(this.content)) {
this.messages = List.of(new UserMessage(this.content));
}
return new Prompt(this.messages, this.chatOptions);
}
}
}

View File

@@ -0,0 +1,59 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.messages;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource;
import java.nio.charset.StandardCharsets;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
/**
* Unit tests for {@link MessageUtils}.
*
* @author Thomas Vitale
*/
class MessageUtilsTests {
@Test
void readResource() {
String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt"));
assertThat(content).isEqualTo("Hello, world!");
}
@Test
void readResourceWhenNull() {
assertThatThrownBy(() -> MessageUtils.readResource(null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("resource cannot be null");
}
@Test
void readResourceWithCharset() {
String content = MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), StandardCharsets.UTF_8);
assertThat(content).isEqualTo("Hello, world!");
}
@Test
void readResourceWithCharsetWhenNull() {
assertThatThrownBy(() -> MessageUtils.readResource(new ClassPathResource("prompt-user.txt"), null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("charset cannot be null");
}
}

View File

@@ -0,0 +1,111 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.messages;
import org.junit.jupiter.api.Test;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE;
/**
* Unit tests for {@link SystemMessage}.
*
* @author Thomas Vitale
*/
class SystemMessageTests {
@Test
void systemMessageWithNullText() {
assertThrows(IllegalArgumentException.class, () -> new SystemMessage((String) null));
}
@Test
void systemMessageWithTextContent() {
String text = "Tell me, did you sail across the sun?";
SystemMessage message = new SystemMessage(text);
assertEquals(text, message.getText());
assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE));
}
@Test
void systemMessageWithNullResource() {
assertThrows(IllegalArgumentException.class, () -> new SystemMessage((Resource) null));
}
@Test
void systemMessageWithResource() {
SystemMessage message = new SystemMessage(new ClassPathResource("prompt-system.txt"));
assertEquals("Tell me, did you sail across the sun?", message.getText());
assertEquals(MessageType.SYSTEM, message.getMetadata().get(MESSAGE_TYPE));
}
@Test
void systemMessageFromBuilderWithText() {
String text = "Tell me, did you sail across the sun?";
SystemMessage message = SystemMessage.builder().text(text).metadata(Map.of("key", "value")).build();
assertEquals(text, message.getText());
assertThat(message.getMetadata()).hasSize(2)
.containsEntry(MESSAGE_TYPE, MessageType.SYSTEM)
.containsEntry("key", "value");
}
@Test
void systemMessageFromBuilderWithResource() {
Resource resource = new ClassPathResource("prompt-system.txt");
SystemMessage message = SystemMessage.builder().text(resource).metadata(Map.of("key", "value")).build();
assertEquals("Tell me, did you sail across the sun?", message.getText());
assertThat(message.getMetadata()).hasSize(2)
.containsEntry(MESSAGE_TYPE, MessageType.SYSTEM)
.containsEntry("key", "value");
}
@Test
void systemMessageCopy() {
String text1 = "Tell me, did you sail across the sun?";
Map<String, Object> metadata1 = Map.of("key", "value");
SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build();
SystemMessage systemMessage2 = systemMessage1.copy();
assertThat(systemMessage2.getText()).isEqualTo(text1);
assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
}
@Test
void systemMessageMutate() {
String text1 = "Tell me, did you sail across the sun?";
Map<String, Object> metadata1 = Map.of("key", "value");
SystemMessage systemMessage1 = SystemMessage.builder().text(text1).metadata(metadata1).build();
SystemMessage systemMessage2 = systemMessage1.mutate().build();
assertThat(systemMessage2.getText()).isEqualTo(text1);
assertThat(systemMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
String text3 = "Farewell, Aragog!";
SystemMessage systemMessage3 = systemMessage2.mutate().text(text3).build();
assertThat(systemMessage3.getText()).isEqualTo(text3);
assertThat(systemMessage3.getMetadata()).hasSize(2).isNotSameAs(systemMessage2.getMetadata());
}
}

View File

@@ -0,0 +1,127 @@
/*
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.messages;
import org.junit.jupiter.api.Test;
import org.springframework.ai.content.Media;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.util.MimeTypeUtils;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.springframework.ai.chat.messages.AbstractMessage.MESSAGE_TYPE;
/**
* Unit tests for {@link UserMessage}.
*
* @author Thomas Vitale
*/
class UserMessageTests {
@Test
void userMessageWithNullText() {
assertThatThrownBy(() -> new UserMessage((String) null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Content must not be null for SYSTEM or USER messages");
;
}
@Test
void userMessageWithTextContent() {
String text = "Hello, world!";
UserMessage message = new UserMessage(text);
assertThat(message.getText()).isEqualTo(text);
assertThat(message.getMedia()).isEmpty();
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
}
@Test
void userMessageWithNullResource() {
assertThatThrownBy(() -> new UserMessage((Resource) null)).isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("resource cannot be null");
;
}
@Test
void userMessageWithResource() {
UserMessage message = new UserMessage(new ClassPathResource("prompt-user.txt"));
assertThat(message.getText()).isEqualTo("Hello, world!");
assertThat(message.getMedia()).isEmpty();
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
}
@Test
void userMessageFromBuilderWithText() {
String text = "Hello, world!";
UserMessage message = UserMessage.builder()
.text(text)
.media(new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt")))
.metadata(Map.of("key", "value"))
.build();
assertThat(message.getText()).isEqualTo(text);
assertThat(message.getMedia()).hasSize(1);
assertThat(message.getMetadata()).hasSize(2)
.containsEntry(MESSAGE_TYPE, MessageType.USER)
.containsEntry("key", "value");
}
@Test
void userMessageFromBuilderWithResource() {
UserMessage message = UserMessage.builder().text(new ClassPathResource("prompt-user.txt")).build();
assertThat(message.getText()).isEqualTo("Hello, world!");
assertThat(message.getMedia()).isEmpty();
assertThat(message.getMetadata()).hasSize(1).containsEntry(MESSAGE_TYPE, MessageType.USER);
}
@Test
void userMessageCopy() {
String text1 = "Hello, world!";
Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt"));
Map<String, Object> metadata1 = Map.of("key", "value");
UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build();
UserMessage userMessage2 = userMessage1.copy();
assertThat(userMessage2.getText()).isEqualTo(text1);
assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1);
assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
}
@Test
void userMessageMutate() {
String text1 = "Hello, world!";
Media media1 = new Media(MimeTypeUtils.TEXT_PLAIN, new ClassPathResource("prompt-user.txt"));
Map<String, Object> metadata1 = Map.of("key", "value");
UserMessage userMessage1 = UserMessage.builder().text(text1).media(media1).metadata(metadata1).build();
UserMessage userMessage2 = userMessage1.mutate().build();
assertThat(userMessage2.getText()).isEqualTo(text1);
assertThat(userMessage2.getMedia()).hasSize(1).isNotSameAs(metadata1);
assertThat(userMessage2.getMetadata()).hasSize(2).isNotSameAs(metadata1);
String text3 = "Farewell, Aragog!";
UserMessage userMessage3 = userMessage2.mutate().text(text3).build();
assertThat(userMessage3.getText()).isEqualTo(text3);
assertThat(userMessage3.getMedia()).hasSize(1).isNotSameAs(metadata1);
assertThat(userMessage3.getMetadata()).hasSize(2).isNotSameAs(metadata1);
}
}

View File

@@ -0,0 +1 @@
Tell me, did you sail across the sun?

View File

@@ -0,0 +1 @@
Hello, world!

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2025 - 2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.test;
/**
* Utility class for escaping curly brackets in strings
*
* @author Christian Tzolov
*
*/
public class CurlyBracketEscaper {
/**
* Escapes all curly brackets in the input string by adding a backslash before them
* @param input The string containing curly brackets to escape
* @return The string with escaped curly brackets
*/
public static String escapeCurlyBrackets(String input) {
if (input == null) {
return null;
}
return input.replace("{", "\\{").replace("}", "\\}");
}
/**
* Unescapes previously escaped curly brackets by removing the backslashes
* @param input The string containing escaped curly brackets
* @return The string with unescaped curly brackets
*/
public static String unescapeCurlyBrackets(String input) {
if (input == null) {
return null;
}
return input.replace("\\{", "{").replace("\\}", "}");
}
}