Reactor PromptTemplateActions

* PromptTemplateActions contains Prompt create() methods
* PromptTemplateStringActions contains String render() methods
* PromptTemplateMessageActions contains Message createMessage() methods
* PromptTemplateChatActions contains List<Message> createMessages() actions
* Message classes can accept a Spring resource in their constructors
* AiClient implementations package name change llm->client
* Add toString() to Generation
* Add -PintegrationTest profile, disabled by default.
* Add integration test for OpenAi Client and 'evaluation'
* Add Question and Answer Prompts for evaluation of AiClient responses
This commit is contained in:
Mark Pollack
2023-08-18 12:47:38 -04:00
parent 236393ef5d
commit bfafa8bde3
23 changed files with 276 additions and 71 deletions

View File

@@ -24,7 +24,7 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }}
ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }}
run: mvn -s settings.xml --batch-mode --update-snapshots deploy
run: mvn -s settings.xml -Pintegration-tests --batch-mode --update-snapshots deploy
- name: Generate Java docs
run: mvn javadoc:aggregate

46
pom.xml
View File

@@ -131,24 +131,6 @@
</excludes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>${maven-failsafe-plugin.version}</version>
<configuration>
<includes>
<include>**/*IntegrationTests.java</include>
</includes>
</configuration>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
@@ -253,6 +235,34 @@
</build>
<profiles>
<profile>
<id>integration-tests</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>${maven-failsafe-plugin.version}</version>
<configuration>
<includes>
<include>**/*IntegrationTests.java</include>
</includes>
</configuration>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
<profile>
<id>test-coverage</id>
<build>

View File

@@ -14,7 +14,7 @@
* limitations under the License.
*/
package org.springframework.ai.azure.openai.llm;
package org.springframework.ai.azure.openai.client;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.*;
@@ -56,7 +56,9 @@ public class AzureOpenAiClient implements AiClient {
options.setTemperature(this.getTemperature());
options.setModel(this.getModel());
logger.trace("Azure Chat Message: ", azureChatMessage);
ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options);
logger.trace("Azure ChatCompletions: ", chatCompletions);
StringBuilder sb = new StringBuilder();
for (ChatChoice choice : chatCompletions.getChoices()) {
if (choice.getMessage() != null && choice.getMessage().getContent() != null) {
@@ -78,7 +80,9 @@ public class AzureOpenAiClient implements AiClient {
ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);
options.setTemperature(this.getTemperature());
options.setModel(this.getModel());
logger.trace("Azure ChatCompletionsOptions: ", options);
ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options);
logger.trace("Azure ChatCompletions: ", chatCompletions);
List<Generation> generations = new ArrayList<>();
for (ChatChoice choice : chatCompletions.getChoices()) {
ChatMessage choiceMessage = choice.getMessage();

View File

@@ -43,4 +43,9 @@ public class Generation {
return Collections.unmodifiableMap(this.info);
}
@Override
public String toString() {
return "Generation{" + "text='" + text + '\'' + ", info=" + info + '}';
}
}

View File

@@ -26,7 +26,7 @@ import java.util.Map;
* A PromptTemplate that lets you specify the role as a string should the current
* implementations and their roles not suffice for your needs.
*/
public class ChatPromptTemplate implements PromptTemplateActions {
public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {
private final List<PromptTemplate> promptTemplates;
@@ -56,7 +56,7 @@ public class ChatPromptTemplate implements PromptTemplateActions {
public List<Message> createMessages() {
List<Message> messages = new ArrayList<>();
for (PromptTemplate promptTemplate : promptTemplates) {
messages.addAll(promptTemplate.createMessages());
messages.add(promptTemplate.createMessage());
}
return messages;
}
@@ -65,7 +65,7 @@ public class ChatPromptTemplate implements PromptTemplateActions {
public List<Message> createMessages(Map<String, Object> model) {
List<Message> messages = new ArrayList<>();
for (PromptTemplate promptTemplate : promptTemplates) {
messages.addAll(promptTemplate.createMessages(model));
messages.add(promptTemplate.createMessage(model));
}
return messages;
}

View File

@@ -33,7 +33,7 @@ import java.util.Map.Entry;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class PromptTemplate implements PromptTemplateActions {
public class PromptTemplate implements PromptTemplateActions, PromptTemplateStringActions, PromptTemplateMessageActions {
private ST st;
@@ -85,6 +85,25 @@ public class PromptTemplate implements PromptTemplateActions {
}
}
public PromptTemplate(Resource resource, Map<String, Object> model) {
try (InputStream inputStream = resource.getInputStream()) {
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
}
catch (IOException ex) {
throw new RuntimeException("Failed to read resource", ex);
}
// If the template string is not valid, an exception will be thrown
try {
this.st = new ST(this.template, '{', '}');
for (Entry<String, Object> entry : model.entrySet()) {
add(entry.getKey(), entry.getValue());
}
}
catch (Exception ex) {
throw new IllegalArgumentException("The template string is not valid.", ex);
}
}
public OutputParser getOutputParser() {
return outputParser;
}
@@ -108,6 +127,7 @@ public class PromptTemplate implements PromptTemplateActions {
}
// Render Methods
@Override
public String render() {
return st.render();
}
@@ -138,13 +158,13 @@ public class PromptTemplate implements PromptTemplateActions {
}
@Override
public List<Message> createMessages() {
return List.of(new UserMessage(render()));
public Message createMessage() {
return new UserMessage(render());
}
@Override
public List<Message> createMessages(Map<String, Object> model) {
return List.of(new UserMessage(render(model)));
public Message createMessage(Map<String, Object> model) {
return new UserMessage(render(model));
}
@Override

View File

@@ -16,20 +16,9 @@
package org.springframework.ai.prompt;
import org.springframework.ai.prompt.messages.Message;
import java.util.List;
import java.util.Map;
public interface PromptTemplateActions {
String render();
String render(Map<String, Object> model);
List<Message> createMessages();
List<Message> createMessages(Map<String, Object> model);
public interface PromptTemplateActions extends PromptTemplateStringActions {
Prompt create();

View File

@@ -0,0 +1,14 @@
package org.springframework.ai.prompt;
import org.springframework.ai.prompt.messages.Message;
import java.util.List;
import java.util.Map;
public interface PromptTemplateChatActions {
List<Message> createMessages();
List<Message> createMessages(Map<String, Object> model);
}

View File

@@ -0,0 +1,13 @@
package org.springframework.ai.prompt;
import org.springframework.ai.prompt.messages.Message;
import java.util.Map;
public interface PromptTemplateMessageActions {
Message createMessage();
Message createMessage(Map<String, Object> model);
}

View File

@@ -0,0 +1,11 @@
package org.springframework.ai.prompt;
import java.util.Map;
public interface PromptTemplateStringActions {
String render();
String render(Map<String, Object> model);
}

View File

@@ -16,15 +16,12 @@
package org.springframework.ai.prompt;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.SystemMessage;
import org.springframework.core.io.Resource;
import org.springframework.util.StreamUtils;
import org.stringtemplate.v4.ST;
public class SystemPromptTemplate extends PromptTemplate {
@@ -36,7 +33,13 @@ public class SystemPromptTemplate extends PromptTemplate {
super(resource);
}
public SystemMessage createMessage(Map<String, Object> model) {
@Override
public Message createMessage() {
return new SystemMessage(render());
}
@Override
public Message createMessage(Map<String, Object> model) {
return new SystemMessage(render(model));
}

View File

@@ -16,6 +16,12 @@
package org.springframework.ai.prompt.messages;
import org.springframework.core.io.Resource;
import org.springframework.util.StreamUtils;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.Map;
@@ -23,6 +29,9 @@ public abstract class AbstractMessage implements Message {
protected String content;
/**
* Additional options for the message to influence the response, not a model map.
*/
protected Map<String, Object> properties = new HashMap<>();
protected MessageType messageType;
@@ -36,10 +45,31 @@ public abstract class AbstractMessage implements Message {
this.content = content;
}
protected AbstractMessage(MessageType messageType, String content, Map<String, Object> properties) {
protected AbstractMessage(MessageType messageType, String content, Map<String, Object> messageProperties) {
this.messageType = messageType;
this.content = content;
this.properties = properties;
this.properties = messageProperties;
}
protected AbstractMessage(MessageType messageType, Resource resource) {
this.messageType = messageType;
try (InputStream inputStream = resource.getInputStream()) {
this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
}
catch (IOException ex) {
throw new RuntimeException("Failed to read resource", ex);
}
}
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> messagePropertiets) {
this.messageType = messageType;
this.properties = messagePropertiets;
try (InputStream inputStream = resource.getInputStream()) {
this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
}
catch (IOException ex) {
throw new RuntimeException("Failed to read resource", ex);
}
}
@Override

View File

@@ -16,6 +16,8 @@
package org.springframework.ai.prompt.messages;
import org.springframework.core.io.Resource;
import java.util.Map;
/**
@@ -30,8 +32,8 @@ public class SystemMessage extends AbstractMessage {
super(MessageType.SYSTEM, content);
}
public SystemMessage(String content, Map<String, Object> properties) {
super(MessageType.SYSTEM, content, properties);
public SystemMessage(Resource resource) {
super(MessageType.SYSTEM, resource);
}
@Override

View File

@@ -16,6 +16,8 @@
package org.springframework.ai.prompt.messages;
import org.springframework.core.io.Resource;
import java.util.Map;
/**
@@ -29,8 +31,8 @@ public class UserMessage extends AbstractMessage {
super(MessageType.USER, message);
}
public UserMessage(String message, Map<String, Object> properties) {
super(MessageType.USER, message, properties);
public UserMessage(Resource resource) {
super(MessageType.USER, resource);
}
@Override

View File

@@ -14,24 +14,24 @@
* limitations under the License.
*/
package org.springframework.ai.openai.llm;
import java.util.ArrayList;
import java.util.List;
package org.springframework.ai.openai.client;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.client.AiClient;
import org.springframework.ai.client.AiResponse;
import org.springframework.ai.client.Generation;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.util.Assert;
import java.util.ArrayList;
import java.util.List;
/**
* Implementation of {@link AiClient} backed by an OpenAiService
*/
@@ -76,19 +76,49 @@ public class OpenAiClient implements AiClient {
@Override
public AiResponse generate(Prompt prompt) {
List<ChatCompletionRequest> chatCompletionRequests = getChatCompletionRequest(prompt);
return getLLMResult(chatCompletionRequests);
List<Message> messages = prompt.getMessages();
List<ChatMessage> theoMessages = new ArrayList<>();
for (Message message : messages) {
String messageType = message.getMessageType().getValue();
theoMessages.add(new ChatMessage(messageType, message.getContent()));
}
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(this.model)
.temperature(this.temperature)
.messages(theoMessages)
.build();
return getAiResponse(chatCompletionRequest);
}
private ChatCompletionRequest getChatCompletionRequest(String text) {
List<ChatMessage> chatMessages = List.of(new ChatMessage("user", text));
logger.trace("ChatMessages: ", chatMessages);
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
.model(this.model)
.temperature(this.temperature)
.messages(List.of(new ChatMessage("user", text)))
.build();
logger.trace("ChatCompletionRequest: ", chatCompletionRequest);
return chatCompletionRequest;
}
private AiResponse getAiResponse(ChatCompletionRequest chatCompletionRequest) {
List<Generation> generations = new ArrayList<>();
logger.trace("ChatMessages: ", chatCompletionRequest.getMessages());
List<ChatCompletionChoice> chatCompletionChoices = this.openAiService
.createChatCompletion(chatCompletionRequest)
.getChoices();
logger.trace("ChatCompletionChoice: ", chatCompletionChoices);
for (ChatCompletionChoice chatCompletionChoice : chatCompletionChoices) {
ChatMessage chatMessage = chatCompletionChoice.getMessage();
// TODO investigate mapping of additional metadata/runtime info to the
// general model.
Generation generation = new Generation(chatMessage.getContent());
generations.add(generation);
}
return new AiResponse(generations);
}
private String getResponse(ChatCompletionRequest chatCompletionRequest) {
StringBuilder builder = new StringBuilder();
this.openAiService.createChatCompletion(chatCompletionRequest).getChoices().forEach(choice -> {
@@ -99,11 +129,6 @@ public class OpenAiClient implements AiClient {
return response;
}
private AiResponse getLLMResult(List<ChatCompletionRequest> chatCompletionRequest) {
// TODO
throw new RuntimeException("LLMResult getLLMResult not yet implemented");
}
private List<ChatCompletionRequest> getChatCompletionRequest(Prompt prompt) {
List<ChatCompletionRequest> chatCompletionRequests = new ArrayList<>();

View File

@@ -2,26 +2,26 @@ package org.springframework.ai.openai;
import com.theokanning.openai.service.OpenAiService;
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
import org.springframework.ai.openai.llm.OpenAiClient;
import org.springframework.ai.openai.client.OpenAiClient;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.util.StringUtils;
import java.io.IOException;
import java.time.Duration;
@SpringBootConfiguration
public class OpenAiTestConfiguration {
@Bean
public OpenAiService theoOpenAiService() throws IOException {
// get api token in file ~/.openai
String apiKey = System.getenv("OPENAI_API_KEY");
if (!StringUtils.hasText(apiKey)) {
throw new IllegalArgumentException(
"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
}
return new OpenAiService(apiKey);
OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60));
return openAiService;
}
@Bean

View File

@@ -0,0 +1,64 @@
package org.springframework.ai.openai.client;
import org.junit.jupiter.api.Test;
import org.springframework.ai.client.Generation;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
import org.springframework.ai.prompt.messages.Message;
import org.springframework.ai.prompt.messages.SystemMessage;
import org.springframework.ai.prompt.messages.UserMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.Resource;
import java.util.List;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest
class ClientIntegrationTests {
@Autowired
OpenAiClient openAiClient;
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;
@Value("classpath:/prompts/system-evaluator-message.st")
private Resource systemEvaluatorResource;
@Value("classpath:/prompts/user-evaluator-message.st")
private Resource userEvaluatorResource;
@Test
void roleTest() {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Generation response = openAiClient.generate(prompt).getGeneration();
System.out.println(response);
assertThat(response).isNotNull();
evaluateQuestionAndAnswer(request, response.getText());
}
private void evaluateQuestionAndAnswer(String question, String answer) {
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
Map.of("question", question, "answer", answer));
SystemMessage systemMessage = new SystemMessage(systemEvaluatorResource);
Message userMessage = userPromptTemplate.createMessage();
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
Generation response = openAiClient.generate(prompt).getGeneration();
System.out.println(response);
assertThat(response.getText()).isEqualTo("YES");
}
}

View File

@@ -0,0 +1,3 @@
You are an AI assistant who helps users to evaluate if the answers to questions are accurate.
You will be provided with a QUESTION and an ANSWER.
Your goal is to evaluate the QUESTION and ANSWER and reply with a YES or NO answer.

View File

@@ -0,0 +1,4 @@
"You are a helpful AI assistant. Your name is {name}.
You are an AI assistant that helps people find information.
Your name is {name}
You should reply to the user's request with your name and also in the style of a {voice}.

View File

@@ -0,0 +1,6 @@
The question and answer to evaluate are:
QUESTION: ```{question}```
ANSWER: ```{answer}```

View File

@@ -20,7 +20,7 @@ import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import org.springframework.ai.azure.openai.llm.AzureOpenAiClient;
import org.springframework.ai.azure.openai.client.AzureOpenAiClient;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;

View File

@@ -19,7 +19,7 @@ package org.springframework.ai.autoconfigure.openai;
import com.theokanning.openai.service.OpenAiService;
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
import org.springframework.ai.openai.llm.OpenAiClient;
import org.springframework.ai.openai.client.OpenAiClient;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.context.properties.EnableConfigurationProperties;