Add ChatClient support for returning ResponseEntity<ChatResponse, T>
ChatClient already provides the .chatResponse() method to return the entire ChatResponse instance. It also provides a set of overloaded .entity(Type) methods to provide Type-converted responses. The new .responseEntity(Type) method returns a ResponseEntity<ChatResponse, T> instance, encapsulating both the ChatResponse and the requested Type-converted response entity. This change allows for more flexibility when handling different response types and facilitates easier integration with other components that expect ResponseEntity instances.
This commit is contained in:
@@ -128,6 +128,12 @@ public interface ChatClient {
|
||||
|
||||
String content();
|
||||
|
||||
<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);
|
||||
|
||||
<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);
|
||||
|
||||
<T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter);
|
||||
|
||||
}
|
||||
|
||||
interface StreamResponseSpec {
|
||||
@@ -205,9 +211,6 @@ public interface ChatClient {
|
||||
|
||||
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
|
||||
|
||||
// ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest,
|
||||
// Map<String, Object> context);
|
||||
|
||||
CallResponseSpec call();
|
||||
|
||||
StreamResponseSpec stream();
|
||||
|
||||
@@ -250,6 +250,28 @@ public class DefaultChatClient implements ChatClient {
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {
|
||||
Assert.notNull(type, "the class must be non-null");
|
||||
return doResponseEntity(new BeanOutputConverter<T>(type));
|
||||
}
|
||||
|
||||
public <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type) {
|
||||
return doResponseEntity(new BeanOutputConverter<T>(type));
|
||||
}
|
||||
|
||||
public <T> ResponseEntity<ChatResponse, T> responseEntity(
|
||||
StructuredOutputConverter<T> structuredOutputConverter) {
|
||||
return doResponseEntity(structuredOutputConverter);
|
||||
}
|
||||
|
||||
protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> boc) {
|
||||
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
|
||||
var responseContent = chatResponse.getResult().getOutput().getContent();
|
||||
T entity = boc.convert(responseContent);
|
||||
|
||||
return new ResponseEntity<>(chatResponse, entity);
|
||||
}
|
||||
|
||||
public <T> T entity(ParameterizedTypeReference<T> type) {
|
||||
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
/**
|
||||
* Represents a {@link org.springframework.ai.model.Model} response that includes the
|
||||
* entire response along withe specified response entity type.
|
||||
*
|
||||
* @param <R> the entire response type.
|
||||
* @param <E> the converted entity type.
|
||||
* @author Christian Tzolov
|
||||
* @since 1.0.0
|
||||
*/
|
||||
public record ResponseEntity<R, E>(R response, E entity) {
|
||||
|
||||
public R getResponse() {
|
||||
return this.response;
|
||||
}
|
||||
|
||||
public E getEntity() {
|
||||
return this.entity;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Copyright 2024-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.chat.client;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Captor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.converter.MapOutputConverter;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* @author Christian Tzolov
|
||||
*/
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
public class ChatClientResponseEntityTests {
|
||||
|
||||
@Mock
|
||||
ChatModel chatModel;
|
||||
|
||||
@Captor
|
||||
ArgumentCaptor<Prompt> promptCaptor;
|
||||
|
||||
record MyBean(String name, int age) {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void responseEntityTest() {
|
||||
|
||||
ChatResponseMetadata metadata = new DefaultChatResponseMetadata();
|
||||
metadata.put("key1", "value1");
|
||||
|
||||
var chatResponse = new ChatResponse(List.of(new Generation("""
|
||||
{"name":"John", "age":30}
|
||||
""")), metadata);
|
||||
|
||||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
|
||||
|
||||
ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.user("Tell me about John")
|
||||
.call()
|
||||
.responseEntity(MyBean.class);
|
||||
|
||||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
|
||||
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1");
|
||||
|
||||
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));
|
||||
|
||||
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
|
||||
assertThat(userMessage.getContent()).contains("Tell me about John");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void parametrizedResponseEntityTest() {
|
||||
|
||||
var chatResponse = new ChatResponse(List.of(new Generation("""
|
||||
[
|
||||
{"name":"Max", "age":10},
|
||||
{"name":"Adi", "age":13}
|
||||
]
|
||||
""")));
|
||||
|
||||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
|
||||
|
||||
ResponseEntity<ChatResponse, List<MyBean>> responseEntity = ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.user("Tell me about them")
|
||||
.call()
|
||||
.responseEntity(new ParameterizedTypeReference<List<MyBean>>() {
|
||||
});
|
||||
|
||||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
|
||||
assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10));
|
||||
assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13));
|
||||
|
||||
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
|
||||
assertThat(userMessage.getContent()).contains("Tell me about them");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void customSoCResponseEntityTest() {
|
||||
|
||||
var chatResponse = new ChatResponse(List.of(new Generation("""
|
||||
{"name":"Max", "age":10},
|
||||
""")));
|
||||
|
||||
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
|
||||
|
||||
ResponseEntity<ChatResponse, Map<String, Object>> responseEntity = ChatClient.builder(chatModel)
|
||||
.build()
|
||||
.prompt()
|
||||
.user("Tell me about Max")
|
||||
.call()
|
||||
.responseEntity(new MapOutputConverter());
|
||||
|
||||
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
|
||||
assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max");
|
||||
assertThat(responseEntity.getEntity().get("age")).isEqualTo(10);
|
||||
|
||||
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
|
||||
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
|
||||
assertThat(userMessage.getContent()).contains("Tell me about Max");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user