Reactor PromptTemplateActions
* PromptTemplateActions contains Prompt create() methods * PromptTemplateStringActions contains String render() methods * PromptTemplateMessageActions contains Message createMessage() methods * PromptTemplateChatActions contains List<Message> createMessages() actions * Message classes can accept a Spring resource in their constructors * AiClient implementations package name change llm->client * Add toString() to Generation * Add -PintegrationTest profile, disabled by default. * Add integration test for OpenAi Client and 'evaluation' * Add Question and Answer Prompts for evaluation of AiClient responses
This commit is contained in:
2
.github/workflows/continuous-integration.yml
vendored
2
.github/workflows/continuous-integration.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }}
|
||||
ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }}
|
||||
run: mvn -s settings.xml --batch-mode --update-snapshots deploy
|
||||
run: mvn -s settings.xml -Pintegration-tests --batch-mode --update-snapshots deploy
|
||||
|
||||
- name: Generate Java docs
|
||||
run: mvn javadoc:aggregate
|
||||
|
||||
46
pom.xml
46
pom.xml
@@ -131,24 +131,6 @@
|
||||
</excludes>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-failsafe-plugin</artifactId>
|
||||
<version>${maven-failsafe-plugin.version}</version>
|
||||
<configuration>
|
||||
<includes>
|
||||
<include>**/*IntegrationTests.java</include>
|
||||
</includes>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>integration-test</goal>
|
||||
<goal>verify</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
@@ -253,6 +235,34 @@
|
||||
</build>
|
||||
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>integration-tests</id>
|
||||
<activation>
|
||||
<activeByDefault>false</activeByDefault>
|
||||
</activation>
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-failsafe-plugin</artifactId>
|
||||
<version>${maven-failsafe-plugin.version}</version>
|
||||
<configuration>
|
||||
<includes>
|
||||
<include>**/*IntegrationTests.java</include>
|
||||
</includes>
|
||||
</configuration>
|
||||
<executions>
|
||||
<execution>
|
||||
<goals>
|
||||
<goal>integration-test</goal>
|
||||
<goal>verify</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</profile>
|
||||
<profile>
|
||||
<id>test-coverage</id>
|
||||
<build>
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.azure.openai.llm;
|
||||
package org.springframework.ai.azure.openai.client;
|
||||
|
||||
import com.azure.ai.openai.OpenAIClient;
|
||||
import com.azure.ai.openai.models.*;
|
||||
@@ -56,7 +56,9 @@ public class AzureOpenAiClient implements AiClient {
|
||||
options.setTemperature(this.getTemperature());
|
||||
options.setModel(this.getModel());
|
||||
|
||||
logger.trace("Azure Chat Message: ", azureChatMessage);
|
||||
ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options);
|
||||
logger.trace("Azure ChatCompletions: ", chatCompletions);
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (ChatChoice choice : chatCompletions.getChoices()) {
|
||||
if (choice.getMessage() != null && choice.getMessage().getContent() != null) {
|
||||
@@ -78,7 +80,9 @@ public class AzureOpenAiClient implements AiClient {
|
||||
ChatCompletionsOptions options = new ChatCompletionsOptions(azureMessages);
|
||||
options.setTemperature(this.getTemperature());
|
||||
options.setModel(this.getModel());
|
||||
logger.trace("Azure ChatCompletionsOptions: ", options);
|
||||
ChatCompletions chatCompletions = this.msoftOpenAiClient.getChatCompletions(this.getModel(), options);
|
||||
logger.trace("Azure ChatCompletions: ", chatCompletions);
|
||||
List<Generation> generations = new ArrayList<>();
|
||||
for (ChatChoice choice : chatCompletions.getChoices()) {
|
||||
ChatMessage choiceMessage = choice.getMessage();
|
||||
@@ -43,4 +43,9 @@ public class Generation {
|
||||
return Collections.unmodifiableMap(this.info);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Generation{" + "text='" + text + '\'' + ", info=" + info + '}';
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ import java.util.Map;
|
||||
* A PromptTemplate that lets you specify the role as a string should the current
|
||||
* implementations and their roles not suffice for your needs.
|
||||
*/
|
||||
public class ChatPromptTemplate implements PromptTemplateActions {
|
||||
public class ChatPromptTemplate implements PromptTemplateActions, PromptTemplateChatActions {
|
||||
|
||||
private final List<PromptTemplate> promptTemplates;
|
||||
|
||||
@@ -56,7 +56,7 @@ public class ChatPromptTemplate implements PromptTemplateActions {
|
||||
public List<Message> createMessages() {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||
messages.addAll(promptTemplate.createMessages());
|
||||
messages.add(promptTemplate.createMessage());
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
@@ -65,7 +65,7 @@ public class ChatPromptTemplate implements PromptTemplateActions {
|
||||
public List<Message> createMessages(Map<String, Object> model) {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
for (PromptTemplate promptTemplate : promptTemplates) {
|
||||
messages.addAll(promptTemplate.createMessages(model));
|
||||
messages.add(promptTemplate.createMessage(model));
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ import java.util.Map.Entry;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
public class PromptTemplate implements PromptTemplateActions {
|
||||
public class PromptTemplate implements PromptTemplateActions, PromptTemplateStringActions, PromptTemplateMessageActions {
|
||||
|
||||
private ST st;
|
||||
|
||||
@@ -85,6 +85,25 @@ public class PromptTemplate implements PromptTemplateActions {
|
||||
}
|
||||
}
|
||||
|
||||
public PromptTemplate(Resource resource, Map<String, Object> model) {
|
||||
try (InputStream inputStream = resource.getInputStream()) {
|
||||
this.template = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new RuntimeException("Failed to read resource", ex);
|
||||
}
|
||||
// If the template string is not valid, an exception will be thrown
|
||||
try {
|
||||
this.st = new ST(this.template, '{', '}');
|
||||
for (Entry<String, Object> entry : model.entrySet()) {
|
||||
add(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
catch (Exception ex) {
|
||||
throw new IllegalArgumentException("The template string is not valid.", ex);
|
||||
}
|
||||
}
|
||||
|
||||
public OutputParser getOutputParser() {
|
||||
return outputParser;
|
||||
}
|
||||
@@ -108,6 +127,7 @@ public class PromptTemplate implements PromptTemplateActions {
|
||||
}
|
||||
|
||||
// Render Methods
|
||||
@Override
|
||||
public String render() {
|
||||
return st.render();
|
||||
}
|
||||
@@ -138,13 +158,13 @@ public class PromptTemplate implements PromptTemplateActions {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> createMessages() {
|
||||
return List.of(new UserMessage(render()));
|
||||
public Message createMessage() {
|
||||
return new UserMessage(render());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> createMessages(Map<String, Object> model) {
|
||||
return List.of(new UserMessage(render(model)));
|
||||
public Message createMessage(Map<String, Object> model) {
|
||||
return new UserMessage(render(model));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -16,20 +16,9 @@
|
||||
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface PromptTemplateActions {
|
||||
|
||||
String render();
|
||||
|
||||
String render(Map<String, Object> model);
|
||||
|
||||
List<Message> createMessages();
|
||||
|
||||
List<Message> createMessages(Map<String, Object> model);
|
||||
public interface PromptTemplateActions extends PromptTemplateStringActions {
|
||||
|
||||
Prompt create();
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface PromptTemplateChatActions {
|
||||
|
||||
List<Message> createMessages();
|
||||
|
||||
List<Message> createMessages(Map<String, Object> model);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface PromptTemplateMessageActions {
|
||||
|
||||
Message createMessage();
|
||||
|
||||
Message createMessage(Map<String, Object> model);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface PromptTemplateStringActions {
|
||||
|
||||
String render();
|
||||
|
||||
String render(Map<String, Object> model);
|
||||
|
||||
}
|
||||
@@ -16,15 +16,12 @@
|
||||
|
||||
package org.springframework.ai.prompt;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
import org.springframework.ai.prompt.messages.SystemMessage;
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.StreamUtils;
|
||||
import org.stringtemplate.v4.ST;
|
||||
|
||||
public class SystemPromptTemplate extends PromptTemplate {
|
||||
|
||||
@@ -36,7 +33,13 @@ public class SystemPromptTemplate extends PromptTemplate {
|
||||
super(resource);
|
||||
}
|
||||
|
||||
public SystemMessage createMessage(Map<String, Object> model) {
|
||||
@Override
|
||||
public Message createMessage() {
|
||||
return new SystemMessage(render());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message createMessage(Map<String, Object> model) {
|
||||
return new SystemMessage(render(model));
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,12 @@
|
||||
|
||||
package org.springframework.ai.prompt.messages;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
import org.springframework.util.StreamUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -23,6 +29,9 @@ public abstract class AbstractMessage implements Message {
|
||||
|
||||
protected String content;
|
||||
|
||||
/**
|
||||
* Additional options for the message to influence the response, not a model map.
|
||||
*/
|
||||
protected Map<String, Object> properties = new HashMap<>();
|
||||
|
||||
protected MessageType messageType;
|
||||
@@ -36,10 +45,31 @@ public abstract class AbstractMessage implements Message {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
protected AbstractMessage(MessageType messageType, String content, Map<String, Object> properties) {
|
||||
protected AbstractMessage(MessageType messageType, String content, Map<String, Object> messageProperties) {
|
||||
this.messageType = messageType;
|
||||
this.content = content;
|
||||
this.properties = properties;
|
||||
this.properties = messageProperties;
|
||||
}
|
||||
|
||||
protected AbstractMessage(MessageType messageType, Resource resource) {
|
||||
this.messageType = messageType;
|
||||
try (InputStream inputStream = resource.getInputStream()) {
|
||||
this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new RuntimeException("Failed to read resource", ex);
|
||||
}
|
||||
}
|
||||
|
||||
protected AbstractMessage(MessageType messageType, Resource resource, Map<String, Object> messagePropertiets) {
|
||||
this.messageType = messageType;
|
||||
this.properties = messagePropertiets;
|
||||
try (InputStream inputStream = resource.getInputStream()) {
|
||||
this.content = StreamUtils.copyToString(inputStream, Charset.defaultCharset());
|
||||
}
|
||||
catch (IOException ex) {
|
||||
throw new RuntimeException("Failed to read resource", ex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package org.springframework.ai.prompt.messages;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
@@ -30,8 +32,8 @@ public class SystemMessage extends AbstractMessage {
|
||||
super(MessageType.SYSTEM, content);
|
||||
}
|
||||
|
||||
public SystemMessage(String content, Map<String, Object> properties) {
|
||||
super(MessageType.SYSTEM, content, properties);
|
||||
public SystemMessage(Resource resource) {
|
||||
super(MessageType.SYSTEM, resource);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
package org.springframework.ai.prompt.messages;
|
||||
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
@@ -29,8 +31,8 @@ public class UserMessage extends AbstractMessage {
|
||||
super(MessageType.USER, message);
|
||||
}
|
||||
|
||||
public UserMessage(String message, Map<String, Object> properties) {
|
||||
super(MessageType.USER, message, properties);
|
||||
public UserMessage(Resource resource) {
|
||||
super(MessageType.USER, resource);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -14,24 +14,24 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.springframework.ai.openai.llm;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
package org.springframework.ai.openai.client;
|
||||
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
|
||||
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
|
||||
import com.theokanning.openai.completion.chat.ChatMessage;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.springframework.ai.client.AiClient;
|
||||
import org.springframework.ai.client.AiResponse;
|
||||
import org.springframework.ai.client.Generation;
|
||||
import org.springframework.ai.prompt.Prompt;
|
||||
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Implementation of {@link AiClient} backed by an OpenAiService
|
||||
*/
|
||||
@@ -76,19 +76,49 @@ public class OpenAiClient implements AiClient {
|
||||
|
||||
@Override
|
||||
public AiResponse generate(Prompt prompt) {
|
||||
List<ChatCompletionRequest> chatCompletionRequests = getChatCompletionRequest(prompt);
|
||||
return getLLMResult(chatCompletionRequests);
|
||||
List<Message> messages = prompt.getMessages();
|
||||
List<ChatMessage> theoMessages = new ArrayList<>();
|
||||
for (Message message : messages) {
|
||||
String messageType = message.getMessageType().getValue();
|
||||
theoMessages.add(new ChatMessage(messageType, message.getContent()));
|
||||
}
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model(this.model)
|
||||
.temperature(this.temperature)
|
||||
.messages(theoMessages)
|
||||
.build();
|
||||
return getAiResponse(chatCompletionRequest);
|
||||
}
|
||||
|
||||
private ChatCompletionRequest getChatCompletionRequest(String text) {
|
||||
List<ChatMessage> chatMessages = List.of(new ChatMessage("user", text));
|
||||
logger.trace("ChatMessages: ", chatMessages);
|
||||
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
|
||||
.model(this.model)
|
||||
.temperature(this.temperature)
|
||||
.messages(List.of(new ChatMessage("user", text)))
|
||||
.build();
|
||||
logger.trace("ChatCompletionRequest: ", chatCompletionRequest);
|
||||
return chatCompletionRequest;
|
||||
}
|
||||
|
||||
private AiResponse getAiResponse(ChatCompletionRequest chatCompletionRequest) {
|
||||
List<Generation> generations = new ArrayList<>();
|
||||
logger.trace("ChatMessages: ", chatCompletionRequest.getMessages());
|
||||
List<ChatCompletionChoice> chatCompletionChoices = this.openAiService
|
||||
.createChatCompletion(chatCompletionRequest)
|
||||
.getChoices();
|
||||
logger.trace("ChatCompletionChoice: ", chatCompletionChoices);
|
||||
for (ChatCompletionChoice chatCompletionChoice : chatCompletionChoices) {
|
||||
ChatMessage chatMessage = chatCompletionChoice.getMessage();
|
||||
// TODO investigate mapping of additional metadata/runtime info to the
|
||||
// general model.
|
||||
Generation generation = new Generation(chatMessage.getContent());
|
||||
generations.add(generation);
|
||||
}
|
||||
return new AiResponse(generations);
|
||||
}
|
||||
|
||||
private String getResponse(ChatCompletionRequest chatCompletionRequest) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
this.openAiService.createChatCompletion(chatCompletionRequest).getChoices().forEach(choice -> {
|
||||
@@ -99,11 +129,6 @@ public class OpenAiClient implements AiClient {
|
||||
return response;
|
||||
}
|
||||
|
||||
private AiResponse getLLMResult(List<ChatCompletionRequest> chatCompletionRequest) {
|
||||
// TODO
|
||||
throw new RuntimeException("LLMResult getLLMResult not yet implemented");
|
||||
}
|
||||
|
||||
private List<ChatCompletionRequest> getChatCompletionRequest(Prompt prompt) {
|
||||
List<ChatCompletionRequest> chatCompletionRequests = new ArrayList<>();
|
||||
|
||||
@@ -2,26 +2,26 @@ package org.springframework.ai.openai;
|
||||
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
|
||||
import org.springframework.ai.openai.llm.OpenAiClient;
|
||||
import org.springframework.ai.openai.client.OpenAiClient;
|
||||
import org.springframework.boot.SpringBootConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
|
||||
@SpringBootConfiguration
|
||||
public class OpenAiTestConfiguration {
|
||||
|
||||
@Bean
|
||||
public OpenAiService theoOpenAiService() throws IOException {
|
||||
// get api token in file ~/.openai
|
||||
String apiKey = System.getenv("OPENAI_API_KEY");
|
||||
|
||||
if (!StringUtils.hasText(apiKey)) {
|
||||
throw new IllegalArgumentException(
|
||||
"You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY");
|
||||
}
|
||||
return new OpenAiService(apiKey);
|
||||
OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60));
|
||||
return openAiService;
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package org.springframework.ai.openai.client;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.client.Generation;
|
||||
import org.springframework.ai.prompt.Prompt;
|
||||
import org.springframework.ai.prompt.PromptTemplate;
|
||||
import org.springframework.ai.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.prompt.messages.Message;
|
||||
import org.springframework.ai.prompt.messages.SystemMessage;
|
||||
import org.springframework.ai.prompt.messages.UserMessage;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.core.io.Resource;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@SpringBootTest
|
||||
class ClientIntegrationTests {
|
||||
|
||||
@Autowired
|
||||
OpenAiClient openAiClient;
|
||||
|
||||
@Value("classpath:/prompts/system-message.st")
|
||||
private Resource systemResource;
|
||||
|
||||
@Value("classpath:/prompts/system-evaluator-message.st")
|
||||
private Resource systemEvaluatorResource;
|
||||
|
||||
@Value("classpath:/prompts/user-evaluator-message.st")
|
||||
private Resource userEvaluatorResource;
|
||||
|
||||
@Test
|
||||
void roleTest() {
|
||||
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
|
||||
String name = "Bob";
|
||||
String voice = "pirate";
|
||||
UserMessage userMessage = new UserMessage(request);
|
||||
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
|
||||
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
Generation response = openAiClient.generate(prompt).getGeneration();
|
||||
System.out.println(response);
|
||||
assertThat(response).isNotNull();
|
||||
|
||||
evaluateQuestionAndAnswer(request, response.getText());
|
||||
|
||||
}
|
||||
|
||||
private void evaluateQuestionAndAnswer(String question, String answer) {
|
||||
PromptTemplate userPromptTemplate = new PromptTemplate(userEvaluatorResource,
|
||||
Map.of("question", question, "answer", answer));
|
||||
SystemMessage systemMessage = new SystemMessage(systemEvaluatorResource);
|
||||
Message userMessage = userPromptTemplate.createMessage();
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
Generation response = openAiClient.generate(prompt).getGeneration();
|
||||
System.out.println(response);
|
||||
assertThat(response.getText()).isEqualTo("YES");
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
You are an AI assistant who helps users to evaluate if the answers to questions are accurate.
|
||||
You will be provided with a QUESTION and an ANSWER.
|
||||
Your goal is to evaluate the QUESTION and ANSWER and reply with a YES or NO answer.
|
||||
@@ -0,0 +1,4 @@
|
||||
"You are a helpful AI assistant. Your name is {name}.
|
||||
You are an AI assistant that helps people find information.
|
||||
Your name is {name}
|
||||
You should reply to the user's request with your name and also in the style of a {voice}.
|
||||
@@ -0,0 +1,6 @@
|
||||
The question and answer to evaluate are:
|
||||
|
||||
QUESTION: ```{question}```
|
||||
|
||||
ANSWER: ```{answer}```
|
||||
|
||||
@@ -20,7 +20,7 @@ import com.azure.ai.openai.OpenAIClient;
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
|
||||
import org.springframework.ai.azure.openai.llm.AzureOpenAiClient;
|
||||
import org.springframework.ai.azure.openai.client.AzureOpenAiClient;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||
|
||||
@@ -19,7 +19,7 @@ package org.springframework.ai.autoconfigure.openai;
|
||||
import com.theokanning.openai.service.OpenAiService;
|
||||
|
||||
import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient;
|
||||
import org.springframework.ai.openai.llm.OpenAiClient;
|
||||
import org.springframework.ai.openai.client.OpenAiClient;
|
||||
import org.springframework.boot.autoconfigure.AutoConfiguration;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
|
||||
Reference in New Issue
Block a user