From d6a0dffd3ec4eea7dd1489b1ced2275ee868e98d Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Mon, 3 Jun 2024 09:04:34 +0200 Subject: [PATCH] Add ChatClient support for returning ResponseEntity ChatClient already provides the .chatResponse() method to return the entire ChatResponse instance. It also provides a set of overloaded .entity(Type) methods to provide Type-converted responses. The new .responseEntity(Type) method returns a ResponseEntity instance, encapsulating both the ChatResponse and the requested Type-converted response entity. This change allows for more flexibility when handling different response types and facilitates easier integration with other components that expect ResponseEntity instances. --- .../ai/chat/client/ChatClient.java | 9 +- .../ai/chat/client/DefaultChatClient.java | 22 +++ .../ai/chat/client/ResponseEntity.java | 37 +++++ .../client/ChatClientResponseEntityTests.java | 141 ++++++++++++++++++ 4 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index add0d0416..65ac3c4c3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -128,6 +128,12 @@ public interface ChatClient { String content(); + ResponseEntity responseEntity(Class type); + + ResponseEntity responseEntity(ParameterizedTypeReference type); + + ResponseEntity responseEntity(StructuredOutputConverter structuredOutputConverter); + } interface StreamResponseSpec { @@ -205,9 +211,6 @@ public interface ChatClient { ChatClientRequestSpec user(Consumer consumer); - // ChatClientRequestSpec adviseOnRequest(ChatClientRequestSpec inputRequest, - // Map context); - CallResponseSpec call(); StreamResponseSpec stream(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index ffb4eb2cf..3748b42a6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -250,6 +250,28 @@ public class DefaultChatClient implements ChatClient { this.request = request; } + public ResponseEntity responseEntity(Class type) { + Assert.notNull(type, "the class must be non-null"); + return doResponseEntity(new BeanOutputConverter(type)); + } + + public ResponseEntity responseEntity(ParameterizedTypeReference type) { + return doResponseEntity(new BeanOutputConverter(type)); + } + + public ResponseEntity responseEntity( + StructuredOutputConverter structuredOutputConverter) { + return doResponseEntity(structuredOutputConverter); + } + + protected ResponseEntity doResponseEntity(StructuredOutputConverter boc) { + var chatResponse = doGetChatResponse(this.request, boc.getFormat()); + var responseContent = chatResponse.getResult().getOutput().getContent(); + T entity = boc.convert(responseContent); + + return new ResponseEntity<>(chatResponse, entity); + } + public T entity(ParameterizedTypeReference type) { return doSingleWithBeanOutputConverter(new BeanOutputConverter(type)); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java new file mode 100644 index 000000000..476fa7833 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ResponseEntity.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +/** + * Represents a {@link org.springframework.ai.model.Model} response that includes the + * entire response along withe specified response entity type. + * + * @param the entire response type. + * @param the converted entity type. + * @author Christian Tzolov + * @since 1.0.0 + */ +public record ResponseEntity(R response, E entity) { + + public R getResponse() { + return this.response; + } + + public E getEntity() { + return this.entity; + } +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java new file mode 100644 index 000000000..2fae3d9f2 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/ChatClientResponseEntityTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.metadata.ChatResponseMetadata.DefaultChatResponseMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.core.ParameterizedTypeReference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + +/** + * @author Christian Tzolov + */ +@ExtendWith(MockitoExtension.class) +public class ChatClientResponseEntityTests { + + @Mock + ChatModel chatModel; + + @Captor + ArgumentCaptor promptCaptor; + + record MyBean(String name, int age) { + } + + @Test + public void responseEntityTest() { + + ChatResponseMetadata metadata = new DefaultChatResponseMetadata(); + metadata.put("key1", "value1"); + + var chatResponse = new ChatResponse(List.of(new Generation(""" + {"name":"John", "age":30} + """)), metadata); + + when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + + ResponseEntity responseEntity = ChatClient.builder(chatModel) + .build() + .prompt() + .user("Tell me about John") + .call() + .responseEntity(MyBean.class); + + assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); + assertThat(responseEntity.getResponse().getMetadata().get("key1")).isEqualTo("value1"); + + assertThat(responseEntity.getEntity()).isEqualTo(new MyBean("John", 30)); + + Message userMessage = promptCaptor.getValue().getInstructions().get(0); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getContent()).contains("Tell me about John"); + } + + @Test + public void parametrizedResponseEntityTest() { + + var chatResponse = new ChatResponse(List.of(new Generation(""" + [ + {"name":"Max", "age":10}, + {"name":"Adi", "age":13} + ] + """))); + + when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + + ResponseEntity> responseEntity = ChatClient.builder(chatModel) + .build() + .prompt() + .user("Tell me about them") + .call() + .responseEntity(new ParameterizedTypeReference>() { + }); + + assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); + assertThat(responseEntity.getEntity().get(0)).isEqualTo(new MyBean("Max", 10)); + assertThat(responseEntity.getEntity().get(1)).isEqualTo(new MyBean("Adi", 13)); + + Message userMessage = promptCaptor.getValue().getInstructions().get(0); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getContent()).contains("Tell me about them"); + } + + @Test + public void customSoCResponseEntityTest() { + + var chatResponse = new ChatResponse(List.of(new Generation(""" + {"name":"Max", "age":10}, + """))); + + when(chatModel.call(promptCaptor.capture())).thenReturn(chatResponse); + + ResponseEntity> responseEntity = ChatClient.builder(chatModel) + .build() + .prompt() + .user("Tell me about Max") + .call() + .responseEntity(new MapOutputConverter()); + + assertThat(responseEntity.getResponse()).isEqualTo(chatResponse); + assertThat(responseEntity.getEntity().get("name")).isEqualTo("Max"); + assertThat(responseEntity.getEntity().get("age")).isEqualTo(10); + + Message userMessage = promptCaptor.getValue().getInstructions().get(0); + assertThat(userMessage.getMessageType()).isEqualTo(MessageType.USER); + assertThat(userMessage.getContent()).contains("Tell me about Max"); + } + +}