From bfafa8bde3bfeef3afc531aa3e48d6e5bd12c994 Mon Sep 17 00:00:00 2001 From: Mark Pollack Date: Fri, 18 Aug 2023 12:47:38 -0400 Subject: [PATCH] Reactor PromptTemplateActions * PromptTemplateActions contains Prompt create() methods * PromptTemplateStringActions contains String render() methods * PromptTemplateMessageActions contains Message createMessage() methods * PromptTemplateChatActions contains List createMessages() actions * Message classes can accept a Spring resource in their constructors * AiClient implementations package name change llm->client * Add toString() to Generation * Add -PintegrationTest profile, disabled by default. * Add integration test for OpenAi Client and 'evaluation' * Add Question and Answer Prompts for evaluation of AiClient responses --- .github/workflows/continuous-integration.yml | 2 +- pom.xml | 46 +++++++------ .../{llm => client}/AzureOpenAiClient.java | 6 +- .../springframework/ai/client/Generation.java | 5 ++ .../ai/prompt/ChatPromptTemplate.java | 6 +- .../ai/prompt/PromptTemplate.java | 30 +++++++-- .../ai/prompt/PromptTemplateActions.java | 13 +--- .../ai/prompt/PromptTemplateChatActions.java | 14 ++++ .../prompt/PromptTemplateMessageActions.java | 13 ++++ .../prompt/PromptTemplateStringActions.java | 11 ++++ .../ai/prompt/SystemPromptTemplate.java | 15 +++-- .../ai/prompt/messages/AbstractMessage.java | 34 +++++++++- .../ai/prompt/messages/SystemMessage.java | 6 +- .../ai/prompt/messages/UserMessage.java | 6 +- .../openai/{llm => client}/OpenAiClient.java | 51 +++++++++++---- .../ai/openai/OpenAiTestConfiguration.java | 8 +-- .../openai/client/ClientIntegrationTests.java | 64 +++++++++++++++++++ .../src/test/{java => }/resources/bikes.json | 0 .../prompts/system-evaluator-message.st | 3 + .../test/resources/prompts/system-message.st | 4 ++ .../prompts/user-evaluator-message.st | 6 ++ .../openai/AzureOpenAiAutoConfiguration.java | 2 +- .../openai/OpenAiAutoConfiguration.java | 2 +- 23 files changed, 276 insertions(+), 71 deletions(-) rename spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/{llm => client}/AzureOpenAiClient.java (92%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java rename spring-ai-openai/src/main/java/org/springframework/ai/openai/{llm => client}/OpenAiClient.java (72%) create mode 100644 spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java rename spring-ai-openai/src/test/{java => }/resources/bikes.json (100%) create mode 100644 spring-ai-openai/src/test/resources/prompts/system-evaluator-message.st create mode 100644 spring-ai-openai/src/test/resources/prompts/system-message.st create mode 100644 spring-ai-openai/src/test/resources/prompts/user-evaluator-message.st diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index de8f25fb3..3dba0f0e3 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -24,7 +24,7 @@ jobs: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} - run: mvn -s settings.xml --batch-mode --update-snapshots deploy + run: mvn -s settings.xml -Pintegration-tests --batch-mode --update-snapshots deploy - name: Generate Java docs run: mvn javadoc:aggregate diff --git a/pom.xml b/pom.xml index 899299151..9e381cce4 100644 --- a/pom.xml +++ b/pom.xml @@ -131,24 +131,6 @@ - - org.apache.maven.plugins - maven-failsafe-plugin - ${maven-failsafe-plugin.version} - - - **/*IntegrationTests.java - - - - - - integration-test - verify - - - - org.apache.maven.plugins maven-jar-plugin @@ -253,6 +235,34 @@ + + integration-tests + + false + + + + + org.apache.maven.plugins + maven-failsafe-plugin + ${maven-failsafe-plugin.version} + + + **/*IntegrationTests.java + + + + + + integration-test + verify + + + + + + + test-coverage diff --git a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/llm/AzureOpenAiClient.java b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java similarity index 92% rename from spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/llm/AzureOpenAiClient.java rename to spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java index eef46e529..c93877556 100644 --- a/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/llm/AzureOpenAiClient.java +++ b/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.azure.openai.llm; +package org.springframework.ai.azure.openai.client; import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.models.*; @@ -56,7 +56,9 @@ public class AzureOpenAiClient implements AiClient { options.setTemperature(this.getTemperature()); options.setModel(this.getModel()); + logger.trace("Azure Chat Message: ", azureChatMessage); ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options); + logger.trace("Azure ChatCompletions: ", chatCompletions); StringBuilder sb = new StringBuilder(); for (ChatChoice choice : chatCompletions.getChoices()) { if (choice.getMessage() != null && choice.getMessage().getContent() != null) { @@ -78,7 +80,9 @@ public class AzureOpenAiClient implements AiClient { ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages); options.setTemperature(this.getTemperature()); options.setModel(this.getModel()); + logger.trace("Azure ChatCompletionsOptions: ", options); ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options); + logger.trace("Azure ChatCompletions: ", chatCompletions); List generations = new ArrayList<>(); for (ChatChoice choice : chatCompletions.getChoices()) { ChatMessage choiceMessage = choice.getMessage(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java index 0c345cc3b..73be362bb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/client/Generation.java @@ -43,4 +43,9 @@ public class Generation { return Collections.unmodifiableMap(this.info); } + @Override + public String toString() { + return "Generation{" + "text='" + text + '\'' + ", info=" + info + '}'; + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java index b31adcc8f..dbef2e4ea 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/ChatPromptTemplate.java @@ -26,7 +26,7 @@ import java.util.Map; * A PromptTemplate that lets you specify the role as a string should the current * implementations and their roles not suffice for your needs. */ -public class ChatPromptTemplate implements PromptTemplateActions { +public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions { private final List promptTemplates; @@ -56,7 +56,7 @@ public class ChatPromptTemplate implements PromptTemplateActions { public List createMessages() { List messages = new ArrayList<>(); for (PromptTemplate promptTemplate : promptTemplates) { - messages.addAll(promptTemplate.createMessages()); + messages.add(promptTemplate.createMessage()); } return messages; } @@ -65,7 +65,7 @@ public class ChatPromptTemplate implements PromptTemplateActions { public List createMessages(Map model) { List messages = new ArrayList<>(); for (PromptTemplate promptTemplate : promptTemplates) { - messages.addAll(promptTemplate.createMessages(model)); + messages.add(promptTemplate.createMessage(model)); } return messages; } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java index 5bfa134f5..500597a75 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplate.java @@ -33,7 +33,7 @@ import java.util.Map.Entry; import java.util.stream.Collectors; import java.util.stream.IntStream; -public class PromptTemplate implements PromptTemplateActions { +public class PromptTemplate implements PromptTemplateActions, PromptTemplateStringActions, PromptTemplateMessageActions { private ST st; @@ -85,6 +85,25 @@ public class PromptTemplate implements PromptTemplateActions { } } + public PromptTemplate(Resource resource, Map model) { + try (InputStream inputStream = resource.getInputStream()) { + this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + // If the template string is not valid, an exception will be thrown + try { + this.st = new ST(this.template, '{', '}'); + for (Entry entry : model.entrySet()) { + add(entry.getKey(), entry.getValue()); + } + } + catch (Exception ex) { + throw new IllegalArgumentException("The template string is not valid.", ex); + } + } + public OutputParser getOutputParser() { return outputParser; } @@ -108,6 +127,7 @@ public class PromptTemplate implements PromptTemplateActions { } // Render Methods + @Override public String render() { return st.render(); } @@ -138,13 +158,13 @@ public class PromptTemplate implements PromptTemplateActions { } @Override - public List createMessages() { - return List.of(new UserMessage(render())); + public Message createMessage() { + return new UserMessage(render()); } @Override - public List createMessages(Map model) { - return List.of(new UserMessage(render(model))); + public Message createMessage(Map model) { + return new UserMessage(render(model)); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java index 76dbea3c3..e1637384f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateActions.java @@ -16,20 +16,9 @@ package org.springframework.ai.prompt; -import org.springframework.ai.prompt.messages.Message; - -import java.util.List; import java.util.Map; -public interface PromptTemplateActions { - - String render(); - - String render(Map model); - - List createMessages(); - - List createMessages(Map model); +public interface PromptTemplateActions extends PromptTemplateStringActions { Prompt create(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java new file mode 100644 index 000000000..b12384c65 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateChatActions.java @@ -0,0 +1,14 @@ +package org.springframework.ai.prompt; + +import org.springframework.ai.prompt.messages.Message; + +import java.util.List; +import java.util.Map; + +public interface PromptTemplateChatActions { + + List createMessages(); + + List createMessages(Map model); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java new file mode 100644 index 000000000..bf48bf546 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateMessageActions.java @@ -0,0 +1,13 @@ +package org.springframework.ai.prompt; + +import org.springframework.ai.prompt.messages.Message; + +import java.util.Map; + +public interface PromptTemplateMessageActions { + + Message createMessage(); + + Message createMessage(Map model); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java new file mode 100644 index 000000000..8c041e81c --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/PromptTemplateStringActions.java @@ -0,0 +1,11 @@ +package org.springframework.ai.prompt; + +import java.util.Map; + +public interface PromptTemplateStringActions { + + String render(); + + String render(Map model); + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java index dd06afaf2..aba74047a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/SystemPromptTemplate.java @@ -16,15 +16,12 @@ package org.springframework.ai.prompt; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.Charset; +import java.util.List; import java.util.Map; +import org.springframework.ai.prompt.messages.Message; import org.springframework.ai.prompt.messages.SystemMessage; import org.springframework.core.io.Resource; -import org.springframework.util.StreamUtils; -import org.stringtemplate.v4.ST; public class SystemPromptTemplate extends PromptTemplate { @@ -36,7 +33,13 @@ public class SystemPromptTemplate extends PromptTemplate { super(resource); } - public SystemMessage createMessage(Map model) { + @Override + public Message createMessage() { + return new SystemMessage(render()); + } + + @Override + public Message createMessage(Map model) { return new SystemMessage(render(model)); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java index f5af9df2e..4e43cf4b5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java @@ -16,6 +16,12 @@ package org.springframework.ai.prompt.messages; +import org.springframework.core.io.Resource; +import org.springframework.util.StreamUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; import java.util.HashMap; import java.util.Map; @@ -23,6 +29,9 @@ public abstract class AbstractMessage implements Message { protected String content; + /** + * Additional options for the message to influence the response, not a model map. + */ protected Map properties = new HashMap<>(); protected MessageType messageType; @@ -36,10 +45,31 @@ public abstract class AbstractMessage implements Message { this.content = content; } - protected AbstractMessage(MessageType messageType, String content, Map properties) { + protected AbstractMessage(MessageType messageType, String content, Map messageProperties) { this.messageType = messageType; this.content = content; - this.properties = properties; + this.properties = messageProperties; + } + + protected AbstractMessage(MessageType messageType, Resource resource) { + this.messageType = messageType; + try (InputStream inputStream = resource.getInputStream()) { + this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } + } + + protected AbstractMessage(MessageType messageType, Resource resource, Map messagePropertiets) { + this.messageType = messageType; + this.properties = messagePropertiets; + try (InputStream inputStream = resource.getInputStream()) { + this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset()); + } + catch (IOException ex) { + throw new RuntimeException("Failed to read resource", ex); + } } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java index a9e2b0e3c..451c540df 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/SystemMessage.java @@ -16,6 +16,8 @@ package org.springframework.ai.prompt.messages; +import org.springframework.core.io.Resource; + import java.util.Map; /** @@ -30,8 +32,8 @@ public class SystemMessage extends AbstractMessage { super(MessageType.SYSTEM, content); } - public SystemMessage(String content, Map properties) { - super(MessageType.SYSTEM, content, properties); + public SystemMessage(Resource resource) { + super(MessageType.SYSTEM, resource); } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java index 106f6ce66..64767e681 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/UserMessage.java @@ -16,6 +16,8 @@ package org.springframework.ai.prompt.messages; +import org.springframework.core.io.Resource; + import java.util.Map; /** @@ -29,8 +31,8 @@ public class UserMessage extends AbstractMessage { super(MessageType.USER, message); } - public UserMessage(String message, Map properties) { - super(MessageType.USER, message, properties); + public UserMessage(Resource resource) { + super(MessageType.USER, resource); } @Override diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/llm/OpenAiClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java similarity index 72% rename from spring-ai-openai/src/main/java/org/springframework/ai/openai/llm/OpenAiClient.java rename to spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java index 775ec70e8..1f84279ac 100644 --- a/spring-ai-openai/src/main/java/org/springframework/ai/openai/llm/OpenAiClient.java +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java @@ -14,24 +14,24 @@ * limitations under the License. */ -package org.springframework.ai.openai.llm; - -import java.util.ArrayList; -import java.util.List; +package org.springframework.ai.openai.client; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; import com.theokanning.openai.completion.chat.ChatCompletionRequest; import com.theokanning.openai.completion.chat.ChatMessage; import com.theokanning.openai.service.OpenAiService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.client.AiClient; import org.springframework.ai.client.AiResponse; +import org.springframework.ai.client.Generation; import org.springframework.ai.prompt.Prompt; - import org.springframework.ai.prompt.messages.Message; import org.springframework.util.Assert; +import java.util.ArrayList; +import java.util.List; + /** * Implementation of {@link AiClient} backed by an OpenAiService */ @@ -76,19 +76,49 @@ public class OpenAiClient implements AiClient { @Override public AiResponse generate(Prompt prompt) { - List chatCompletionRequests = getChatCompletionRequest(prompt); - return getLLMResult(chatCompletionRequests); + List messages = prompt.getMessages(); + List theoMessages = new ArrayList<>(); + for (Message message : messages) { + String messageType = message.getMessageType().getValue(); + theoMessages.add(new ChatMessage(messageType, message.getContent())); + } + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() + .model(this.model) + .temperature(this.temperature) + .messages(theoMessages) + .build(); + return getAiResponse(chatCompletionRequest); } private ChatCompletionRequest getChatCompletionRequest(String text) { + List chatMessages = List.of(new ChatMessage("user", text)); + logger.trace("ChatMessages: ", chatMessages); ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() .model(this.model) .temperature(this.temperature) .messages(List.of(new ChatMessage("user", text))) .build(); + logger.trace("ChatCompletionRequest: ", chatCompletionRequest); return chatCompletionRequest; } + private AiResponse getAiResponse(ChatCompletionRequest chatCompletionRequest) { + List generations = new ArrayList<>(); + logger.trace("ChatMessages: ", chatCompletionRequest.getMessages()); + List chatCompletionChoices = this.openAiService + .createChatCompletion(chatCompletionRequest) + .getChoices(); + logger.trace("ChatCompletionChoice: ", chatCompletionChoices); + for (ChatCompletionChoice chatCompletionChoice : chatCompletionChoices) { + ChatMessage chatMessage = chatCompletionChoice.getMessage(); + // TODO investigate mapping of additional metadata/runtime info to the + // general model. + Generation generation = new Generation(chatMessage.getContent()); + generations.add(generation); + } + return new AiResponse(generations); + } + private String getResponse(ChatCompletionRequest chatCompletionRequest) { StringBuilder builder = new StringBuilder(); this.openAiService.createChatCompletion(chatCompletionRequest).getChoices().forEach(choice -> { @@ -99,11 +129,6 @@ public class OpenAiClient implements AiClient { return response; } - private AiResponse getLLMResult(List chatCompletionRequest) { - // TODO - throw new RuntimeException("LLMResult getLLMResult not yet implemented"); - } - private List getChatCompletionRequest(Prompt prompt) { List chatCompletionRequests = new ArrayList<>(); diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index fbe808239..34ccb902d 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -2,26 +2,26 @@ package org.springframework.ai.openai; import com.theokanning.openai.service.OpenAiService; import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient; -import org.springframework.ai.openai.llm.OpenAiClient; +import org.springframework.ai.openai.client.OpenAiClient; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; import java.io.IOException; +import java.time.Duration; @SpringBootConfiguration public class OpenAiTestConfiguration { @Bean public OpenAiService theoOpenAiService() throws IOException { - // get api token in file ~/.openai String apiKey = System.getenv("OPENAI_API_KEY"); - if (!StringUtils.hasText(apiKey)) { throw new IllegalArgumentException( "You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY"); } - return new OpenAiService(apiKey); + OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60)); + return openAiService; } @Bean diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java new file mode 100644 index 000000000..5053556f1 --- /dev/null +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIntegrationTests.java @@ -0,0 +1,64 @@ +package org.springframework.ai.openai.client; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.client.Generation; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.PromptTemplate; +import org.springframework.ai.prompt.SystemPromptTemplate; +import org.springframework.ai.prompt.messages.Message; +import org.springframework.ai.prompt.messages.SystemMessage; +import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest +class ClientIntegrationTests { + + @Autowired + OpenAiClient openAiClient; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Value("classpath:/prompts/system-evaluator-message.st") + private Resource systemEvaluatorResource; + + @Value("classpath:/prompts/user-evaluator-message.st") + private Resource userEvaluatorResource; + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Generation response = openAiClient.generate(prompt).getGeneration(); + System.out.println(response); + assertThat(response).isNotNull(); + + evaluateQuestionAndAnswer(request, response.getText()); + + } + + private void evaluateQuestionAndAnswer(String question, String answer) { + PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource, + Map.of("question", question, "answer", answer)); + SystemMessage systemMessage = new SystemMessage(systemEvaluatorResource); + Message userMessage = userPromptTemplate.createMessage(); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Generation response = openAiClient.generate(prompt).getGeneration(); + System.out.println(response); + assertThat(response.getText()).isEqualTo("YES"); + } + +} diff --git a/spring-ai-openai/src/test/java/resources/bikes.json b/spring-ai-openai/src/test/resources/bikes.json similarity index 100% rename from spring-ai-openai/src/test/java/resources/bikes.json rename to spring-ai-openai/src/test/resources/bikes.json diff --git a/spring-ai-openai/src/test/resources/prompts/system-evaluator-message.st b/spring-ai-openai/src/test/resources/prompts/system-evaluator-message.st new file mode 100644 index 000000000..562703595 --- /dev/null +++ b/spring-ai-openai/src/test/resources/prompts/system-evaluator-message.st @@ -0,0 +1,3 @@ +You are an AI assistant who helps users to evaluate if the answers to questions are accurate. +You will be provided with a QUESTION and an ANSWER. +Your goal is to evaluate the QUESTION and ANSWER and reply with a YES or NO answer. \ No newline at end of file diff --git a/spring-ai-openai/src/test/resources/prompts/system-message.st b/spring-ai-openai/src/test/resources/prompts/system-message.st new file mode 100644 index 000000000..dc2cf2dcd --- /dev/null +++ b/spring-ai-openai/src/test/resources/prompts/system-message.st @@ -0,0 +1,4 @@ +"You are a helpful AI assistant. Your name is {name}. +You are an AI assistant that helps people find information. +Your name is {name} +You should reply to the user's request with your name and also in the style of a {voice}. \ No newline at end of file diff --git a/spring-ai-openai/src/test/resources/prompts/user-evaluator-message.st b/spring-ai-openai/src/test/resources/prompts/user-evaluator-message.st new file mode 100644 index 000000000..b3fa3e902 --- /dev/null +++ b/spring-ai-openai/src/test/resources/prompts/user-evaluator-message.st @@ -0,0 +1,6 @@ +The question and answer to evaluate are: + +QUESTION: ```{question}``` + +ANSWER: ```{answer}``` + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index de500e057..f2c18b94b 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -20,7 +20,7 @@ import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; -import org.springframework.ai.azure.openai.llm.AzureOpenAiClient; +import org.springframework.ai.azure.openai.client.AzureOpenAiClient; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index cdafe04d3..387ad4953 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -19,7 +19,7 @@ package org.springframework.ai.autoconfigure.openai; import com.theokanning.openai.service.OpenAiService; import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient; -import org.springframework.ai.openai.llm.OpenAiClient; +import org.springframework.ai.openai.client.OpenAiClient; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.context.properties.EnableConfigurationProperties;