change Node to Content and move into model package.

This commit is contained in:
Mark Pollack
2024-04-19 12:24:38 -04:00
parent ed48662cd9
commit 7148cd10a4
9 changed files with 32 additions and 41 deletions

View File

@@ -15,10 +15,7 @@
*/
package org.springframework.ai.chat.messages;
import org.springframework.ai.node.Node;
import java.util.List;
import java.util.Map;
import org.springframework.ai.model.Content;
/**
* The Message interface represents a message that can be sent or received in a chat
@@ -28,9 +25,7 @@ import java.util.Map;
* @see Media
* @see MessageType
*/
public interface Message extends Node<String> {
String getContent();
public interface Message extends Content {
MessageType getMessageType();

View File

@@ -1,7 +1,7 @@
package org.springframework.ai.chat.prompt.transformer;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.node.Node;
import org.springframework.ai.model.Content;
import java.util.*;
@@ -16,7 +16,7 @@ public class PromptContext {
private Prompt prompt; // The most up-to-date prompt to use
private List<Node<?>> nodes; // The most up-to-date data to use
private List<Content> contents; // The most up-to-date data to use
private List<Prompt> promptHistory;
@@ -33,11 +33,11 @@ public class PromptContext {
this.conversationId = conversationId;
}
public PromptContext(Prompt prompt, List<Node<?>> nodes) {
public PromptContext(Prompt prompt, List<Content> contents) {
this.prompt = prompt;
this.promptHistory = new ArrayList<>();
this.promptHistory.add(prompt);
this.nodes = nodes;
this.contents = contents;
}
public Prompt getPrompt() {
@@ -48,16 +48,16 @@ public class PromptContext {
this.prompt = prompt;
}
public void addData(Node<?> datum) {
this.nodes.add(datum);
public void addData(Content datum) {
this.contents.add(datum);
}
public List<Node<?>> getNodes() {
return nodes;
public List<Content> getNodes() {
return contents;
}
public void setNodes(List<Node<?>> nodes) {
this.nodes = nodes;
public void setNodes(List<Content> contents) {
this.contents = contents;
}
public void addPromptHistory(Prompt prompt) {
@@ -78,7 +78,7 @@ public class PromptContext {
@Override
public String toString() {
return "PromptContext{" + "prompt=" + prompt + ", nodes=" + nodes + ", promptHistory=" + promptHistory
return "PromptContext{" + "prompt=" + prompt + ", contents=" + contents + ", promptHistory=" + promptHistory
+ ", conversationId='" + conversationId + '\'' + ", metadata=" + metadata + '}';
}
@@ -88,14 +88,14 @@ public class PromptContext {
return true;
if (!(o instanceof PromptContext that))
return false;
return Objects.equals(prompt, that.prompt) && Objects.equals(nodes, that.nodes)
return Objects.equals(prompt, that.prompt) && Objects.equals(contents, that.contents)
&& Objects.equals(promptHistory, that.promptHistory)
&& Objects.equals(conversationId, that.conversationId) && Objects.equals(metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(prompt, nodes, promptHistory, conversationId, metadata);
return Objects.hash(prompt, contents, promptHistory, conversationId, metadata);
}
}

View File

@@ -6,7 +6,7 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.node.Node;
import org.springframework.ai.node.Content;
import java.util.List;
import java.util.Map;
@@ -44,11 +44,11 @@ public class QuestionContextAugmentor implements PromptTransformer {
return promptContext;
}
protected String doCreateContext(List<Node<?>> data) {
protected String doCreateContext(List<Content> data) {
return data.stream()
.filter(node -> node instanceof Document)
.map(node -> (Document) node)
.map(Node::getContent)
.map(Content::getContent)
.collect(Collectors.joining(System.lineSeparator()));
}

View File

@@ -28,7 +28,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import org.springframework.ai.chat.messages.Media;
import org.springframework.ai.document.id.IdGenerator;
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.node.Node;
import org.springframework.ai.model.Content;
import org.springframework.util.Assert;
/**
@@ -36,7 +36,7 @@ import org.springframework.util.Assert;
* the document's unique ID and an optional embedding.
*/
@JsonIgnoreProperties({ "contentFormatter" })
public class Document implements Node<String> {
public class Document implements Content {
public final static ContentFormatter DEFAULT_CONTENT_FORMATTER = DefaultContentFormatter.defaultConfig();

View File

@@ -3,7 +3,7 @@ package org.springframework.ai.evaluation;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.agent.AgentResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.node.Node;
import org.springframework.ai.model.Content;
import java.util.List;
import java.util.Objects;
@@ -12,7 +12,7 @@ public class EvaluationRequest {
private final Prompt prompt;
private final List<Node<?>> dataList;
private final List<Content> dataList;
private final ChatResponse chatResponse;
@@ -21,7 +21,7 @@ public class EvaluationRequest {
agentResponse.getChatResponse());
}
public EvaluationRequest(Prompt prompt, List<Node<?>> dataList, ChatResponse chatResponse) {
public EvaluationRequest(Prompt prompt, List<Content> dataList, ChatResponse chatResponse) {
this.prompt = prompt;
this.dataList = dataList;
this.chatResponse = chatResponse;
@@ -31,7 +31,7 @@ public class EvaluationRequest {
return prompt;
}
public List<Node<?>> getDataList() {
public List<Content> getDataList() {
return dataList;
}

View File

@@ -6,7 +6,7 @@ import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.node.Node;
import org.springframework.ai.model.Content;
import java.util.Collections;
import java.util.List;
@@ -61,11 +61,11 @@ public class RelevancyEvaluator implements Evaluator {
}
protected String doGetSupportingData(EvaluationRequest evaluationRequest) {
List<Node<?>> data = evaluationRequest.getDataList();
List<Content> data = evaluationRequest.getDataList();
String supportingData = data.stream()
.filter(node -> node != null && node.getContent() instanceof String)
.map(node -> (Node<String>) node)
.map(Node::getContent)
.map(node -> (Content) node)
.map(Content::getContent)
.collect(Collectors.joining("\n"));
return supportingData;
}

View File

@@ -1,4 +1,4 @@
package org.springframework.ai.node;
package org.springframework.ai.model;
import org.springframework.ai.chat.messages.Media;
@@ -12,9 +12,9 @@ import java.util.Map;
* @author Mark Pollack
* @since 1.0 M1
*/
public interface Node<T> {
public interface Content {
T getContent();
String getContent();
List<Media> getMedia();

View File

@@ -1,4 +0,0 @@
/**
* This package contains base data types used across Spring AI.
*/
package org.springframework.ai.node;

View File

@@ -177,7 +177,7 @@ public class Neo4jVectorStore implements VectorStore, InitializingBean {
*/
public Builder withLabel(String newLabel) {
Assert.hasText(newLabel, "Node label may not be null or blank");
Assert.hasText(newLabel, "Content label may not be null or blank");
this.label = newLabel;
return this;