Add ChatClient support for returning ResponseEntity<ChatResponse, T>

ChatClient already provides the .chatResponse() method to return the entire ChatResponse instance.
 It also provides a set of overloaded .entity(Type) methods to provide Type-converted responses.
 The new .responseEntity(Type) method returns a ResponseEntity<ChatResponse, T> instance, encapsulating
 both the ChatResponse and the requested Type-converted response entity.

 This change allows for more flexibility when handling different response types and facilitates
 easier integration with other components that expect ResponseEntity instances.
This commit is contained in:
Christian Tzolov
2024-06-03 09:04:34 +02:00
parent 6ad36b7653
commit d6a0dffd3e
4 changed files with 206 additions and 3 deletions

View File

@@ -128,6 +128,12 @@ public interface ChatClient {
String content();
<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);
<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);
<T> ResponseEntity<ChatResponse, T> responseEntity(StructuredOutputConverter<T> structuredOutputConverter);
}
interface StreamResponseSpec {
@@ -205,9 +211,6 @@ public interface ChatClient {
ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);
// ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest,
// Map<String, Object> context);
CallResponseSpec call();
StreamResponseSpec stream();

View File

@@ -250,6 +250,28 @@ public class DefaultChatClient implements ChatClient {
this.request = request;
}
public <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {
Assert.notNull(type, "the class must be non-null");
return doResponseEntity(new BeanOutputConverter<T>(type));
}
public <T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type) {
return doResponseEntity(new BeanOutputConverter<T>(type));
}
public <T> ResponseEntity<ChatResponse, T> responseEntity(
StructuredOutputConverter<T> structuredOutputConverter) {
return doResponseEntity(structuredOutputConverter);
}
protected <T> ResponseEntity<ChatResponse, T> doResponseEntity(StructuredOutputConverter<T> boc) {
var chatResponse = doGetChatResponse(this.request, boc.getFormat());
var responseContent = chatResponse.getResult().getOutput().getContent();
T entity = boc.convert(responseContent);
return new ResponseEntity<>(chatResponse, entity);
}
public <T> T entity(ParameterizedTypeReference<T> type) {
return doSingleWithBeanOutputConverter(new BeanOutputConverter<T>(type));
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
/**
* Represents a {@link org.springframework.ai.model.Model} response that includes the
* entire response along withe specified response entity type.
*
* @param <R> the entire response type.
* @param <E> the converted entity type.
* @author Christian Tzolov
* @since 1.0.0
*/
public record ResponseEntity<R, E>(R response, E entity) {
public R getResponse() {
return this.response;
}
public E getEntity() {
return this.entity;
}
}

View File

@@ -0,0 +1,141 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.chat.client;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.core.ParameterizedTypeReference;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.when;
/**
* @author Christian Tzolov
*/
@ExtendWith(MockitoExtension.class)
public class ChatClientResponseEntityTests {
@Mock
ChatModel chatModel;
@Captor
ArgumentCaptor<Prompt> promptCaptor;
record MyBean(String name, int age) {
}
@Test
public void responseEntityTest() {
ChatResponseMetadata metadata = new DefaultChatResponseMetadata();
metadata.put("key1", "value1");
var chatResponse = new ChatResponse(List.of(new Generation("""
{"name":"John", "age":30}
""")), metadata);
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
ResponseEntity<ChatResponse, MyBean> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about John")
.call()
.responseEntity(MyBean.class);
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1");
assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30));
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about John");
}
@Test
public void parametrizedResponseEntityTest() {
var chatResponse = new ChatResponse(List.of(new Generation("""
[
{"name":"Max", "age":10},
{"name":"Adi", "age":13}
]
""")));
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
ResponseEntity<ChatResponse, List<MyBean>> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about them")
.call()
.responseEntity(new ParameterizedTypeReference<List<MyBean>>() {
});
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10));
assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13));
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about them");
}
@Test
public void customSoCResponseEntityTest() {
var chatResponse = new ChatResponse(List.of(new Generation("""
{"name":"Max", "age":10},
""")));
when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse);
ResponseEntity<ChatResponse, Map<String, Object>> responseEntity = ChatClient.builder(chatModel)
.build()
.prompt()
.user("Tell me about Max")
.call()
.responseEntity(new MapOutputConverter());
assertThat(responseEntity.getResponse()).isEqualTo(chatResponse);
assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max");
assertThat(responseEntity.getEntity().get("age")).isEqualTo(10);
Message userMessage = promptCaptor.getValue().getInstructions().get(0);
assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER);
assertThat(userMessage.getContent()).contains("Tell me about Max");
}
}